proto: add generic Size
Change-Id: I4ed123f4a9747fb4aba392bc5b9608d294bacc4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169697
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go
index ac727e8..2f73872 100644
--- a/internal/cmd/generate-types/main.go
+++ b/internal/cmd/generate-types/main.go
@@ -42,6 +42,7 @@
writeSource("internal/prototype/protofile_list_gen.go", generateListTypes())
writeSource("proto/decode_gen.go", generateProtoDecode())
writeSource("proto/encode_gen.go", generateProtoEncode())
+ writeSource("proto/size_gen.go", generateProtoSize())
}
// chdirRoot changes the working directory to the repository root.
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index dea9581..7d688cd 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -269,3 +269,30 @@
return b, nil
}
`))
+
+func generateProtoSize() string {
+ return mustExecute(protoSizeTemplate, ProtoKinds)
+}
+
+var protoSizeTemplate = template.Must(template.New("").Parse(`
+func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
+ switch kind {
+ {{- range .}}
+ case {{.Expr}}:
+ {{if (eq .Name "Message") -}}
+ return wire.SizeBytes(sizeMessage(v.Message()))
+ {{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
+ return wire.Size{{.WireType}}()
+ {{- else if (eq .WireType "Bytes") -}}
+ return wire.Size{{.WireType}}(len({{.FromValue}}))
+ {{- else if (eq .WireType "Group") -}}
+ return wire.Size{{.WireType}}(num, sizeMessage(v.Message()))
+ {{- else -}}
+ return wire.Size{{.WireType}}({{.FromValue}})
+ {{- end}}
+ {{- end}}
+ default:
+ return 0
+ }
+}
+`))
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 8dc218b..feb4ac6 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -10,6 +10,7 @@
"testing"
protoV1 "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/v2/encoding/textpb"
"github.com/golang/protobuf/v2/internal/encoding/pack"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
@@ -32,7 +33,7 @@
wire := append(([]byte)(nil), test.wire...)
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(wire, got); err != nil {
- t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
@@ -43,7 +44,7 @@
}
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
- t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
}
@@ -901,3 +902,8 @@
}
}
}
+
+func marshalText(m proto.Message) string {
+ b, _ := textpb.Marshal(m)
+ return string(b)
+}
diff --git a/proto/encode_test.go b/proto/encode_test.go
index 4c5034d..d467b74 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -7,7 +7,6 @@
"testing"
protoV1 "github.com/golang/protobuf/proto"
- //_ "github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/proto"
"github.com/google/go-cmp/cmp"
)
@@ -18,7 +17,12 @@
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
wire, err := proto.Marshal(want)
if err != nil {
- t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
+ }
+
+ size := proto.Size(want)
+ if size != len(wire) {
+ t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
}
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
@@ -41,12 +45,12 @@
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
if err != nil {
- t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
if err != nil {
- t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
if !bytes.Equal(wire, wire2) {
@@ -55,12 +59,12 @@
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(wire, got); err != nil {
- t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
- t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
+ t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
}
diff --git a/proto/size.go b/proto/size.go
new file mode 100644
index 0000000..8c9263f
--- /dev/null
+++ b/proto/size.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "fmt"
+
+ "github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Size returns the size in bytes of the wire-format encoding of m.
+func Size(m Message) int {
+ return sizeMessage(m.ProtoReflect())
+}
+
+func sizeMessage(m protoreflect.Message) (size int) {
+ fields := m.Type().Fields()
+ knownFields := m.KnownFields()
+ m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
+ field := fields.ByNumber(num)
+ if field == nil {
+ field = knownFields.ExtensionTypes().ByNumber(num)
+ if field == nil {
+ panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Type().FullName()))
+ }
+ }
+ size += sizeField(field, value)
+ return true
+ })
+ m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
+ size += len(raw)
+ return true
+ })
+ return size
+}
+
+func sizeField(field protoreflect.FieldDescriptor, value protoreflect.Value) (size int) {
+ num := field.Number()
+ kind := field.Kind()
+ switch {
+ case field.Cardinality() != protoreflect.Repeated:
+ return wire.SizeTag(num) + sizeSingular(num, kind, value)
+ case field.IsMap():
+ return sizeMap(num, kind, field.MessageType(), value.Map())
+ case field.IsPacked():
+ return sizePacked(num, kind, value.List())
+ default:
+ return sizeList(num, kind, value.List())
+ }
+}
+
+func sizeMap(num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) (size int) {
+ keyf := mdesc.Fields().ByNumber(1)
+ valf := mdesc.Fields().ByNumber(2)
+ mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool {
+ size += wire.SizeTag(num)
+ size += wire.SizeBytes(sizeField(keyf, key.Value()) + sizeField(valf, value))
+ return true
+ })
+ return size
+}
+
+func sizePacked(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
+ content := 0
+ for i, llen := 0, list.Len(); i < llen; i++ {
+ content += sizeSingular(num, kind, list.Get(i))
+ }
+ return wire.SizeTag(num) + wire.SizeBytes(content)
+}
+
+func sizeList(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
+ for i, llen := 0, list.Len(); i < llen; i++ {
+ size += wire.SizeTag(num) + sizeSingular(num, kind, list.Get(i))
+ }
+ return size
+}
diff --git a/proto/size_gen.go b/proto/size_gen.go
new file mode 100644
index 0000000..d71c7c7
--- /dev/null
+++ b/proto/size_gen.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+// Code generated by generate-types. DO NOT EDIT.
+
+package proto
+
+import (
+ "github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
+ switch kind {
+ case protoreflect.BoolKind:
+ return wire.SizeVarint(wire.EncodeBool(v.Bool()))
+ case protoreflect.EnumKind:
+ return wire.SizeVarint(uint64(v.Enum()))
+ case protoreflect.Int32Kind:
+ return wire.SizeVarint(uint64(int32(v.Int())))
+ case protoreflect.Sint32Kind:
+ return wire.SizeVarint(wire.EncodeZigZag(int64(int32(v.Int()))))
+ case protoreflect.Uint32Kind:
+ return wire.SizeVarint(uint64(uint32(v.Uint())))
+ case protoreflect.Int64Kind:
+ return wire.SizeVarint(uint64(v.Int()))
+ case protoreflect.Sint64Kind:
+ return wire.SizeVarint(wire.EncodeZigZag(v.Int()))
+ case protoreflect.Uint64Kind:
+ return wire.SizeVarint(v.Uint())
+ case protoreflect.Sfixed32Kind:
+ return wire.SizeFixed32()
+ case protoreflect.Fixed32Kind:
+ return wire.SizeFixed32()
+ case protoreflect.FloatKind:
+ return wire.SizeFixed32()
+ case protoreflect.Sfixed64Kind:
+ return wire.SizeFixed64()
+ case protoreflect.Fixed64Kind:
+ return wire.SizeFixed64()
+ case protoreflect.DoubleKind:
+ return wire.SizeFixed64()
+ case protoreflect.StringKind:
+ return wire.SizeBytes(len([]byte(v.String())))
+ case protoreflect.BytesKind:
+ return wire.SizeBytes(len(v.Bytes()))
+ case protoreflect.MessageKind:
+ return wire.SizeBytes(sizeMessage(v.Message()))
+ case protoreflect.GroupKind:
+ return wire.SizeGroup(num, sizeMessage(v.Message()))
+ default:
+ return 0
+ }
+}