blob: 02f07d480d048aa6cc2630e8445cc37f1fff4cdb [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto_test
import (
"fmt"
"reflect"
"testing"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
testpb "google.golang.org/protobuf/internal/testprotos/test"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
)
// TestDecode walks testValidMessages and, for every expected message in an
// entry's decodeTo list, unmarshals the entry's wire bytes into a newly
// allocated message of the same type and compares it with the expectation.
// An entry with an empty decodeTo list is reported as an error.
func TestDecode(t *testing.T) {
	for _, tc := range testValidMessages {
		if len(tc.decodeTo) == 0 {
			t.Errorf("%v: no test message types", tc.desc)
		}
		for _, expected := range tc.decodeTo {
			name := fmt.Sprintf("%s (%T)", tc.desc, expected)
			t.Run(name, func(t *testing.T) {
				// Decode from a private copy so the shared test input is
				// never modified by the aliasing check below.
				buf := append(([]byte)(nil), tc.wire...)
				decoded := reflect.New(reflect.TypeOf(expected).Elem()).Interface().(proto.Message)

				unmarshalOpts := proto.UnmarshalOptions{
					AllowPartial: tc.partial,
				}
				err := unmarshalOpts.Unmarshal(buf, decoded)
				if err != nil {
					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(expected))
					return
				}

				// Aliasing check: clobber the input buffer. If the decoded
				// message still referenced it, the comparison below would
				// observe the zeroed bytes.
				for i := range buf {
					buf[i] = 0
				}

				if proto.Equal(decoded, expected) {
					return
				}
				// A mismatch is only reported when both messages are valid.
				if decoded.ProtoReflect().IsValid() && expected.ProtoReflect().IsValid() {
					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(decoded), marshalText(expected))
				}
			})
		}
	}
}
// TestDecodeRequiredFieldChecks takes the entries of testValidMessages that
// are flagged as partial and verifies that proto.Unmarshal, called without
// AllowPartial, rejects their wire bytes for every listed message type.
func TestDecodeRequiredFieldChecks(t *testing.T) {
	for _, tc := range testValidMessages {
		if tc.partial {
			for _, msg := range tc.decodeTo {
				name := fmt.Sprintf("%s (%T)", tc.desc, msg)
				t.Run(name, func(t *testing.T) {
					decoded := reflect.New(reflect.TypeOf(msg).Elem()).Interface().(proto.Message)
					err := proto.Unmarshal(tc.wire, decoded)
					if err != nil {
						// Rejection is the expected outcome.
						return
					}
					t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(decoded))
				})
			}
		}
	}
}
// TestDecodeInvalidMessages verifies that every entry of testInvalidMessages
// fails to unmarshal into each message type listed in its decodeTo field.
// An entry with an empty decodeTo list is reported as an error.
func TestDecodeInvalidMessages(t *testing.T) {
	for _, tc := range testInvalidMessages {
		if len(tc.decodeTo) == 0 {
			t.Errorf("%v: no test message types", tc.desc)
		}
		for _, tmpl := range tc.decodeTo {
			name := fmt.Sprintf("%s (%T)", tc.desc, tmpl)
			t.Run(name, func(t *testing.T) {
				// Decode into a new message of the same type as the template.
				decoded := tmpl.ProtoReflect().New().Interface()
				unmarshalOpts := proto.UnmarshalOptions{
					AllowPartial: tc.partial,
				}
				err := unmarshalOpts.Unmarshal(tc.wire, decoded)
				if err != nil {
					// Rejection is the expected outcome.
					return
				}
				t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", tc.wire, marshalText(decoded))
			})
		}
	}
}
func TestDecodeZeroLengthBytes(t *testing.T) {
// Verify that proto3 bytes fields don't give the mistaken
// impression that they preserve presence.
wire := pack.Message{
pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
}.Marshal()
m := &test3pb.TestAllTypes{}
if err := proto.Unmarshal(wire, m); err != nil {
t.Fatal(err)
}
if m.OptionalBytes != nil {
t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
}
}
// TestDecodeOneofNilWrapper unmarshals a varint for field 111 into a message
// whose OneofField already holds a typed nil *TestAllTypes_OneofUint32
// wrapper, and checks that GetOneofUint32 reports the decoded value.
func TestDecodeOneofNilWrapper(t *testing.T) {
	const want = 1111

	encoded := pack.Message{
		pack.Tag{111, pack.VarintType}, pack.Varint(want),
	}.Marshal()

	// Start from a oneof that is populated with a nil wrapper pointer.
	msg := &testpb.TestAllTypes{
		OneofField: (*testpb.TestAllTypes_OneofUint32)(nil),
	}

	err := proto.Unmarshal(encoded, msg)
	if err != nil {
		t.Fatal(err)
	}
	got := msg.GetOneofUint32()
	if got != want {
		t.Errorf("GetOneofUint32() = %v, want %v", got, want)
	}
}
// TestDecodeEmptyBytes unmarshals a single zero-length bytes record for
// field 45 and checks that the resulting RepeatedBytes entry is a non-nil
// slice.
func TestDecodeEmptyBytes(t *testing.T) {
	// There's really nothing wrong with a nil entry in a [][]byte,
	// but we take care to produce non-nil []bytes for zero-length
	// byte strings, so test for it.
	m := &testpb.TestAllTypes{}
	b := pack.Message{
		pack.Tag{45, pack.BytesType}, pack.Bytes(nil),
	}.Marshal()
	if err := proto.Unmarshal(b, m); err != nil {
		t.Fatal(err)
	}
	// Guard the index below: if decoding produced no elements, fail with a
	// readable message rather than an index-out-of-range panic.
	if len(m.RepeatedBytes) == 0 {
		t.Fatalf("unmarshaling repeated bytes field containing zero-length value: got no elements, want 1")
	}
	if m.RepeatedBytes[0] == nil {
		t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
	}
}
// build runs each option against m, in the order given, and returns m.
// The message is modified in place; the return value is the same message
// that was passed in.
func build(m proto.Message, opts ...buildOpt) proto.Message {
	for i := range opts {
		apply := opts[i]
		apply(m)
	}
	return m
}
// buildOpt is a function that is applied to a message; build calls each
// buildOpt it is given with the message being constructed.
type buildOpt func(proto.Message)
// unknown returns a buildOpt that sets raw as the unknown fields of the
// message it is applied to, via the message's reflective SetUnknown method.
func unknown(raw protoreflect.RawFields) buildOpt {
	setUnknown := func(m proto.Message) {
		refl := m.ProtoReflect()
		refl.SetUnknown(raw)
	}
	return setUnknown
}
// extend returns a buildOpt that calls proto.SetExtension with desc and
// value on the message it is applied to.
func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
	setExt := func(m proto.Message) {
		proto.SetExtension(m, desc, value)
	}
	return setExt
}
// marshalText renders m in the protobuf text format for use in test failure
// messages. A nil message is rendered as "<nil>\n". Output is tab-indented,
// includes unknown fields, and partial messages are allowed.
func marshalText(m proto.Message) string {
	if m == nil {
		return "<nil>\n"
	}
	textOpts := prototext.MarshalOptions{
		AllowPartial: true,
		EmitUnknown:  true,
		Indent:       "\t",
	}
	// The error is intentionally discarded: this helper only produces
	// diagnostic output, and whatever bytes Marshal returned are used as-is.
	out, _ := textOpts.Marshal(m)
	return string(out)
}