blob: f329357c253c2d7e503f35da2e7c2971ef597f6d [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protocmp
import (
"bytes"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/reflect/protoreflect"
)
// This implements a custom text marshaler similar to the prototext format.
// We don't use the prototext marshaler so that we can:
// • have finer grain control over the ordering of fields
// • marshal maps with more aesthetically pleasing output
//
// TODO: If the prototext format gains a map-specific syntax, consider just
// using the prototext marshaler instead.
// appendValue appends a textual rendering of v to b and returns the extended
// slice. Scalars are printed with fmt.Sprint, strings and byte slices are
// quoted with strconv.Quote, enums use their String method, messages and raw
// unknown fields are rendered recursively, and any other slice or map is
// rendered via reflection. Values of any other kind cause a panic.
func appendValue(b []byte, v interface{}) []byte {
	switch x := v.(type) {
	case bool, int32, int64, uint32, uint64, float32, float64:
		return append(b, fmt.Sprint(x)...)
	case string:
		return append(b, strconv.Quote(x)...)
	case []byte:
		return append(b, strconv.Quote(string(x))...)
	case Enum:
		return append(b, x.String()...)
	case Message:
		return appendMessage(b, x)
	case protoreflect.RawFields:
		// Unknown fields are decoded first, then rendered like any other value.
		return appendValue(b, transformRawFields(x))
	}

	// Not one of the directly supported types: fall back to reflection for
	// repeated fields (slices) and map fields.
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Slice:
		return appendList(b, rv)
	case reflect.Map:
		return appendMap(b, rv)
	}
	panic(fmt.Sprintf("invalid type: %v", rv.Type()))
}
// appendMessage appends a brace-enclosed rendering of m to b and returns the
// extended slice. Each entry is written as "key:value", with the value
// rendered by appendValue and entries separated by delim().
//
// Entries are grouped and ordered as follows:
//   - known fields (keys that are valid protobuf names), ordered by the
//     field's index in m's message descriptor;
//   - extension fields (keys of the form "[...]"), ordered lexicographically;
//   - unknown fields (keys made only of decimal digits), ordered numerically.
//
// Keys that match none of these groups are not printed.
func appendMessage(b []byte, m Message) []byte {
	var knownKeys, extensionKeys, unknownKeys []string
	for k := range m {
		switch {
		case protoreflect.Name(k).IsValid():
			knownKeys = append(knownKeys, k)
		case strings.HasPrefix(k, "[") && strings.HasSuffix(k, "]"):
			extensionKeys = append(extensionKeys, k)
		case len(strings.Trim(k, "0123456789")) == 0:
			unknownKeys = append(unknownKeys, k)
		}
	}

	// Resolve the field list once instead of twice per comparison. It is
	// only needed when there are at least two known keys to order.
	if len(knownKeys) > 1 {
		fields := m.Descriptor().Fields()
		sort.Slice(knownKeys, func(i, j int) bool {
			fdi := fields.ByName(protoreflect.Name(knownKeys[i]))
			fdj := fields.ByName(protoreflect.Name(knownKeys[j]))
			return fdi.Index() < fdj.Index()
		})
	}
	sort.Strings(extensionKeys)
	sort.Slice(unknownKeys, func(i, j int) bool {
		// Keys were filtered to digit-only strings above, so the Atoi error
		// is deliberately ignored.
		ni, _ := strconv.Atoi(unknownKeys[i])
		nj, _ := strconv.Atoi(unknownKeys[j])
		return ni < nj
	})

	// Evaluate the separator once so the appended and trimmed separators
	// are guaranteed to match.
	sep := delim()
	b = append(b, '{')
	for _, group := range [][]string{knownKeys, extensionKeys, unknownKeys} {
		for _, k := range group {
			b = append(b, k...)
			b = append(b, ':')
			b = appendValue(b, m[k])
			b = append(b, sep...)
		}
	}
	// Remove exactly one trailing separator. bytes.TrimRight would treat sep
	// as a set of characters and could also strip trailing ',' or ' '
	// characters that belong to the last value's rendering.
	b = bytes.TrimSuffix(b, []byte(sep))
	return append(b, '}')
}
func appendList(b []byte, v reflect.Value) []byte {
b = append(b, '[')
for i := 0; i < v.Len(); i++ {
b = appendValue(b, v.Index(i).Interface())
b = append(b, delim()...)
}
b = bytes.TrimRight(b, delim())
b = append(b, ']')
return b
}
func appendMap(b []byte, v reflect.Value) []byte {
ks := v.MapKeys()
sort.Slice(ks, func(i, j int) bool {
ki, kj := ks[i], ks[j]
switch ki.Kind() {
case reflect.Bool:
return !ki.Bool() && kj.Bool()
case reflect.Int32, reflect.Int64:
return ki.Int() < kj.Int()
case reflect.Uint32, reflect.Uint64:
return ki.Uint() < kj.Uint()
case reflect.String:
return ki.String() < kj.String()
default:
panic(fmt.Sprintf("invalid kind: %v", ki.Kind()))
}
})
b = append(b, '{')
for _, k := range ks {
b = appendValue(b, k.Interface())
b = append(b, ':')
b = appendValue(b, v.MapIndex(k).Interface())
b = append(b, delim()...)
}
b = bytes.TrimRight(b, delim())
b = append(b, '}')
return b
}
// transformRawFields decodes a sequence of wire-format fields into values
// that appendValue can render, one value per field (see transformRawField).
// A single field yields its decoded value directly; any other number of
// fields yields a []interface{} holding the decoded values in input order.
func transformRawFields(b protoreflect.RawFields) interface{} {
	var fields []interface{}
	for rest := b; len(rest) > 0; {
		num, typ, tagLen := wire.ConsumeTag(rest)
		payload := rest[tagLen:]
		valLen := wire.ConsumeFieldValue(num, typ, payload)
		fields = append(fields, transformRawField(typ, payload[:valLen]))
		rest = rest[tagLen+valLen:]
	}

	switch len(fields) {
	case 1:
		return fields[0]
	default:
		return fields
	}
}
// transformRawField decodes the value bytes b of a single field of wire type
// typ (the tag must already be stripped). Scalar wire types are decoded to
// uint64, uint32, or []byte. A group becomes a Message keyed by decimal
// field number, where each entry holds the concatenated raw bytes (tag
// included) of every occurrence of that field up to the end-group tag.
// Any other wire type panics.
func transformRawField(typ wire.Type, b protoreflect.RawFields) interface{} {
	switch typ {
	case wire.VarintType:
		x, _ := wire.ConsumeVarint(b)
		return x
	case wire.Fixed32Type:
		x, _ := wire.ConsumeFixed32(b)
		return x
	case wire.Fixed64Type:
		x, _ := wire.ConsumeFixed64(b)
		return x
	case wire.BytesType:
		x, _ := wire.ConsumeBytes(b)
		return x
	case wire.StartGroupType:
		group := Message{}
		for {
			num, fieldTyp, tagLen := wire.ConsumeTag(b)
			if fieldTyp == wire.EndGroupType {
				return group
			}
			valLen := wire.ConsumeFieldValue(num, fieldTyp, b[tagLen:])
			key := strconv.Itoa(int(num))
			// Repeated occurrences of the same field number are accumulated
			// as raw bytes under one key.
			prev, _ := group[key].(protoreflect.RawFields)
			group[key] = append(prev, b[:tagLen+valLen]...)
			b = b[tagLen+valLen:]
		}
	}
	panic(fmt.Sprintf("invalid type: %v", typ))
}
// delim returns the separator written between adjacent entries: either ", "
// or " ", as selected by detrand.Bool.
func delim() string {
	// The separator is deliberately unstable so that users do not come to
	// depend on the exact text of the rendered message.
	if !detrand.Bool() {
		return ", "
	}
	return " "
}