all: refactor extensions, add proto.GetExtension etc.
Change protoiface.ExtensionDescV1 to implement protoreflect.ExtensionType.
ExtensionDescV1's Name field conflicts with the Descriptor Name method,
so change the protoreflect.{Message,Enum,Extension}Type types to no
longer implement the corresponding Descriptor interface. This also leads
to a clearer distinction between the two types.
Introduce a protoreflect.ExtensionTypeDescriptor type which bridges
between ExtensionType and ExtensionDescriptor.
Add extension accessor functions to the proto package:
proto.{Has,Clear,Get,Set}Extension. These functions take a
protoreflect.ExtensionType parameter, which allows writing the
same function call using either the old or new API:
proto.GetExtension(message, somepb.E_ExtensionFoo)
Fixes golang/protobuf#908
Change-Id: Ibc65d12a46666297849114fd3aefbc4a597d9f08
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189199
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go-grpc/testdata/go.mod b/cmd/protoc-gen-go-grpc/testdata/go.mod
index 67f6a1c..2b2c716 100644
--- a/cmd/protoc-gen-go-grpc/testdata/go.mod
+++ b/cmd/protoc-gen-go-grpc/testdata/go.mod
@@ -1,7 +1,7 @@
module google.golang.org/protobuf/cmd/protoc-gen-go-grpc/testdata
require (
- github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4
+ github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0
google.golang.org/grpc v1.19.0
google.golang.org/protobuf v1.0.0
)
diff --git a/cmd/protoc-gen-go/testdata/go.mod b/cmd/protoc-gen-go/testdata/go.mod
index fbf038f..2413632 100644
--- a/cmd/protoc-gen-go/testdata/go.mod
+++ b/cmd/protoc-gen-go/testdata/go.mod
@@ -1,7 +1,7 @@
module google.golang.org/protobuf/cmd/protoc-gen-go/testdata
require (
- github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4
+ github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0
google.golang.org/protobuf v1.0.0
)
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index 7c0da23..1d22d3a 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -182,7 +182,9 @@
if err != nil && err != protoregistry.NotFound {
return errors.New("unable to resolve [%v]: %v", extName, err)
}
- fd = extType
+ if extType != nil {
+ fd = extType.Descriptor()
+ }
} else {
// The name can either be the JSON name or the proto field name.
fd = fieldDescs.ByJSONName(name)
diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go
index 2385dc4..9c4f93e 100644
--- a/encoding/protojson/decode_test.go
+++ b/encoding/protojson/decode_test.go
@@ -1200,10 +1200,10 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_OptExtBool, true)
- setExtension(m, pb2.E_OptExtString, "extension field")
- setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -1225,9 +1225,9 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
- setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+ proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1250,10 +1250,10 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
- setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -1282,9 +1282,9 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1323,13 +1323,13 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
OptString: proto.String("just a regular extension"),
})
return m
@@ -1345,7 +1345,7 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.FakeMessageSet{}
- setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
return m
@@ -1371,7 +1371,7 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("another not a messageset extension"),
})
return m
@@ -2402,7 +2402,7 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_OptExtNested, &pb2.Nested{})
+ proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{})
return m
}(),
}, {
diff --git a/encoding/protojson/encode_test.go b/encoding/protojson/encode_test.go
index 2aa61dd..5cacf9e 100644
--- a/encoding/protojson/encode_test.go
+++ b/encoding/protojson/encode_test.go
@@ -16,7 +16,6 @@
pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry"
- "google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/encoding/testprotos/pb2"
"google.golang.org/protobuf/encoding/testprotos/pb3"
@@ -29,11 +28,6 @@
"google.golang.org/protobuf/types/known/wrapperspb"
)
-// TODO: Replace this with proto.SetExtension.
-func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
- m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
-}
-
func TestMarshal(t *testing.T) {
tests := []struct {
desc string
@@ -886,10 +880,10 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_OptExtBool, true)
- setExtension(m, pb2.E_OptExtString, "extension field")
- setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -915,9 +909,9 @@
desc: "extensions of repeated fields",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
- setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+ proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -950,10 +944,10 @@
desc: "extensions of non-repeated fields in another message",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
- setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -980,9 +974,9 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1018,13 +1012,13 @@
desc: "MessageSet",
input: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
OptString: proto.String("just a regular extension"),
})
return m
@@ -1045,7 +1039,7 @@
desc: "not real MessageSet 1",
input: func() proto.Message {
m := &pb2.FakeMessageSet{}
- setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
return m
@@ -1060,7 +1054,7 @@
desc: "not real MessageSet 2",
input: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("another not a messageset extension"),
})
return m
diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go
index e744536..77833d8 100644
--- a/encoding/protojson/well_known_types.go
+++ b/encoding/protojson/well_known_types.go
@@ -189,7 +189,7 @@
// If type of value has custom JSON encoding, marshal out a field "value"
// with corresponding custom JSON encoding of the embedded message as a
// field.
- if isCustomType(emt.FullName()) {
+ if isCustomType(emt.Descriptor().FullName()) {
o.encoder.WriteName("value")
return o.marshalCustomType(em)
}
@@ -235,7 +235,7 @@
// Create new message for the embedded message type and unmarshal into it.
em := emt.New()
- if isCustomType(emt.FullName()) {
+ if isCustomType(emt.Descriptor().FullName()) {
// If embedded message is a custom type, unmarshal the JSON "value" field
// into it.
if err := o.unmarshalAnyValue(em); err != nil {
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index 4f384a0..06388c8 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -126,7 +126,9 @@
if err != nil && err != protoregistry.NotFound {
return errors.New("unable to resolve [%v]: %v", extName, err)
}
- fd = xt
+ if xt != nil {
+ fd = xt.Descriptor()
+ }
}
if fd == nil {
diff --git a/encoding/prototext/decode_test.go b/encoding/prototext/decode_test.go
index 20ce133..588d9ee 100644
--- a/encoding/prototext/decode_test.go
+++ b/encoding/prototext/decode_test.go
@@ -1171,10 +1171,10 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_OptExtBool, true)
- setExtension(m, pb2.E_OptExtString, "extension field")
- setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -1207,9 +1207,9 @@
`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
- setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+ proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1231,10 +1231,10 @@
`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
- setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -1269,9 +1269,9 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1299,13 +1299,13 @@
`,
wantMessage: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
OptString: proto.String("just a regular extension"),
})
return m
@@ -1321,7 +1321,7 @@
`,
wantMessage: func() proto.Message {
m := &pb2.FakeMessageSet{}
- setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
return m
@@ -1346,7 +1346,7 @@
}`,
wantMessage: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("another not a messageset extension"),
})
return m
diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go
index c29169b..493a229 100644
--- a/encoding/prototext/encode_test.go
+++ b/encoding/prototext/encode_test.go
@@ -16,7 +16,6 @@
pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry"
- "google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/encoding/testprotos/pb2"
"google.golang.org/protobuf/encoding/testprotos/pb3"
@@ -28,11 +27,6 @@
detrand.Disable()
}
-// TODO: Use proto.SetExtension when available.
-func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
- m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
-}
-
func TestMarshal(t *testing.T) {
tests := []struct {
desc string
@@ -905,10 +899,10 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_OptExtBool, true)
- setExtension(m, pb2.E_OptExtString, "extension field")
- setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -933,7 +927,7 @@
desc: "extension field contains invalid UTF-8",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_OptExtString, "abc\xff")
+ proto.SetExtension(m, pb2.E_OptExtString, "abc\xff")
return m
}(),
wantErr: true,
@@ -941,10 +935,10 @@
desc: "extension partial returns error",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{
+ proto.SetExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{
OptString: proto.String("partial1"),
})
- setExtension(m, pb2.E_ExtensionsContainer_OptExtPartial, &pb2.PartialRequired{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtPartial, &pb2.PartialRequired{
OptString: proto.String("partial2"),
})
return m
@@ -962,7 +956,7 @@
mo: prototext.MarshalOptions{AllowPartial: true},
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{
+ proto.SetExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{
OptString: proto.String("partial1"),
})
return m
@@ -975,9 +969,9 @@
desc: "extensions of repeated fields",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
- setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+ proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1003,10 +997,10 @@
desc: "extensions of non-repeated fields in another message",
input: func() proto.Message {
m := &pb2.Extensions{}
- setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
- setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
- setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
OptString: proto.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: proto.String("another nested in an extension"),
@@ -1032,9 +1026,9 @@
OptBool: proto.Bool(true),
OptInt32: proto.Int32(42),
}
- setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
- setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+ proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: proto.String("one")},
&pb2.Nested{OptString: proto.String("two")},
&pb2.Nested{OptString: proto.String("three")},
@@ -1063,13 +1057,13 @@
desc: "MessageSet",
input: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
- setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+ proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
OptString: proto.String("just a regular extension"),
})
return m
@@ -1089,7 +1083,7 @@
desc: "not real MessageSet 1",
input: func() proto.Message {
m := &pb2.FakeMessageSet{}
- setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("not a messageset extension"),
})
return m
@@ -1103,7 +1097,7 @@
desc: "not real MessageSet 2",
input: func() proto.Message {
m := &pb2.MessageSet{}
- setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: proto.String("another not a messageset extension"),
})
return m
diff --git a/go.mod b/go.mod
index d9c1034..c29b074 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,6 @@
go 1.9
require (
- github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4
+ github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0
github.com/google/go-cmp v0.3.0
)
diff --git a/go.sum b/go.sum
index 274f450..de16d81 100644
--- a/go.sum
+++ b/go.sum
@@ -6,8 +6,8 @@
github.com/golang/protobuf v1.2.1-0.20190605195750-76c9e09470ba/go.mod h1:S1YIJXvYHGRCG2UmZsOcElkAYfvZLg2sDRr9+Xu8JXU=
github.com/golang/protobuf v1.2.1-0.20190617175902-f94016f5239f/go.mod h1:G+HpKX7pYZAVkElkAWZkr08MToW6pTp/vs+E9osFfbg=
github.com/golang/protobuf v1.2.1-0.20190620192300-1ee46dfd80dd/go.mod h1:+CMAsi9jpYf/wAltLUKlg++CWXqxCJyD8iLDbQONsJs=
-github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4 h1:Hj8cGYPgLw3MR0AGL0GFObh4pq8i31QOWWMCE0KY9z4=
-github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4/go.mod h1:tDQPRlaHYu9yt1wPgdx85inRiLvUCuJZXsYjC0mwc1c=
+github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0 h1:a3hJDGxxWRbPxfOMiV6aG8pb0I+8RdgICRdXjXjiKzY=
+github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0/go.mod h1:tDQPRlaHYu9yt1wPgdx85inRiLvUCuJZXsYjC0mwc1c=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
google.golang.org/protobuf v0.0.0-20190514172829-e89e6244e0e8/go.mod h1:791zQGC15vDqjpmPRn1uGPu5oHy/Jzw/Q1n5JsgIIcY=
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
index 7d05dee..b1c6db5 100644
--- a/internal/encoding/messageset/messageset.go
+++ b/internal/encoding/messageset/messageset.go
@@ -65,7 +65,7 @@
if err != nil {
return nil, err
}
- if !IsMessageSetExtension(xt) {
+ if !IsMessageSetExtension(xt.Descriptor()) {
return nil, preg.NotFound
}
return xt, nil
diff --git a/internal/filetype/build.go b/internal/filetype/build.go
index f30f98c..15c8f27 100644
--- a/internal/filetype/build.go
+++ b/internal/filetype/build.go
@@ -7,11 +7,9 @@
package filetype
import (
- "fmt"
"reflect"
"sync"
- "google.golang.org/protobuf/internal/descfmt"
"google.golang.org/protobuf/internal/descopts"
fdesc "google.golang.org/protobuf/internal/filedesc"
pimpl "google.golang.org/protobuf/internal/impl"
@@ -358,8 +356,7 @@
t.lazyInit()
return t.goType
}
-func (t *Extension) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor }
-func (t *Extension) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, t) }
+func (t *Extension) Descriptor() pref.ExtensionTypeDescriptor { return (*extDesc)(t) }
// ProtoLegacyExtensionDesc is a pseudo-internal API for allowing the v1 code
// to be able to retrieve a v1 ExtensionDesc.
@@ -379,3 +376,8 @@
})
return t.conv
}
+
+type extDesc Extension
+
+func (t *extDesc) Type() pref.ExtensionType { return (*Extension)(t) }
+func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor }
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index 3a23bb0..7e430ee 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -29,25 +29,26 @@
return e
}
+ xd := xt.Descriptor()
var wiretag uint64
- if !xt.IsPacked() {
- wiretag = wire.EncodeTag(xt.Number(), wireTypes[xt.Kind()])
+ if !xd.IsPacked() {
+ wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
} else {
- wiretag = wire.EncodeTag(xt.Number(), wire.BytesType)
+ wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
}
e = &extensionFieldInfo{
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
- funcs: encoderFuncsForValue(xt, xt.GoType()),
+ funcs: encoderFuncsForValue(xd, xt.GoType()),
}
// Does the unmarshal function need a value passed to it?
// This is true for composite types, where we pass in a message, list, or map to fill in,
// and for enums, where we pass in a prototype value to specify the concrete enum type.
- switch xt.Kind() {
+ switch xd.Kind() {
case pref.MessageKind, pref.GroupKind, pref.EnumKind:
e.unmarshalNeedsValue = true
default:
- if xt.Cardinality() == pref.Repeated {
+ if xd.Cardinality() == pref.Repeated {
e.unmarshalNeedsValue = true
}
}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index e43c812..7014861 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -44,8 +44,9 @@
mi.extensionOffset = si.extensionOffset
mi.coderFields = make(map[wire.Number]*coderFieldInfo)
- for i := 0; i < mi.PBType.Fields().Len(); i++ {
- fd := mi.PBType.Fields().Get(i)
+ fields := mi.PBType.Descriptor().Fields()
+ for i := 0; i < fields.Len(); i++ {
+ fd := fields.Get(i)
fs := si.fieldsByNumber[fd.Number()]
if fd.ContainingOneof() != nil {
@@ -81,7 +82,7 @@
}
if messageset.IsMessageSet(mi.PBType.Descriptor()) {
if !mi.extensionOffset.IsValid() {
- panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName()))
+ panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.Descriptor().FullName()))
}
cf := &coderFieldInfo{
num: messageset.FieldItem,
@@ -113,7 +114,7 @@
mi.denseCoderFields[cf.num] = cf
}
- mi.needsInitCheck = needsInitCheck(mi.PBType)
+ mi.needsInitCheck = needsInitCheck(mi.PBType.Descriptor())
mi.methods = piface.Methods{
Flags: piface.SupportMarshalDeterministic | piface.SupportUnmarshalDiscardUnknown,
MarshalAppend: mi.marshalAppend,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 4e4c7f3..4821852 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -138,7 +138,7 @@
xt := x.GetType()
if xt == nil {
var err error
- xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num)
+ xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.Descriptor().FullName(), num)
if err != nil {
if err == preg.NotFound {
return 0, errUnknown
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index ca00012..079afe0 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -29,7 +29,7 @@
if p.IsNil() {
for _, f := range mi.orderedCoderFields {
if f.isRequired {
- return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+ return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
}
}
return nil
@@ -47,7 +47,7 @@
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
if f.isRequired {
- return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+ return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
}
continue
}
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
index aaf8fcc..2da4d71 100644
--- a/internal/impl/legacy_extension.go
+++ b/internal/impl/legacy_extension.go
@@ -5,11 +5,9 @@
package impl
import (
- "fmt"
"reflect"
"sync"
- "google.golang.org/protobuf/internal/descfmt"
ptag "google.golang.org/protobuf/internal/encoding/tag"
"google.golang.org/protobuf/internal/filedesc"
pref "google.golang.org/protobuf/reflect/protoreflect"
@@ -62,8 +60,9 @@
}
// Determine the parent type if possible.
+ xd := xt.Descriptor()
var parent piface.MessageV1
- messageName := xt.ContainingMessage().FullName()
+ messageName := xd.ContainingMessage().FullName()
if mt, _ := preg.GlobalTypes.FindMessageByName(messageName); mt != nil {
// Create a new parent message and unwrap it if possible.
mv := mt.New().Interface()
@@ -94,7 +93,7 @@
// Reconstruct the legacy enum full name, which is an odd mixture of the
// proto package name with the Go type name.
var enumName string
- if xt.Kind() == pref.EnumKind {
+ if xd.Kind() == pref.EnumKind {
// Derive Go type name.
t := extType
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
@@ -105,7 +104,7 @@
// Derive the proto package name.
// For legacy enums, obtain the proto package from the raw descriptor.
var protoPkg string
- if fd := xt.Enum().ParentFile(); fd != nil {
+ if fd := xd.Enum().ParentFile(); fd != nil {
protoPkg = string(fd.Package())
}
if ed, ok := reflect.Zero(t).Interface().(enumV1); ok && protoPkg == "" {
@@ -120,7 +119,7 @@
// Derive the proto file that the extension was declared within.
var filename string
- if fd := xt.ParentFile(); fd != nil {
+ if fd := xd.ParentFile(); fd != nil {
filename = fd.Path()
}
@@ -129,9 +128,9 @@
Type: xt,
ExtendedType: parent,
ExtensionType: reflect.Zero(extType).Interface(),
- Field: int32(xt.Number()),
- Name: string(xt.FullName()),
- Tag: ptag.Marshal(xt, enumName),
+ Field: int32(xd.Number()),
+ Name: string(xd.FullName()),
+ Tag: ptag.Marshal(xd, enumName),
Filename: filename,
}
if d, ok := legacyExtensionDescCache.LoadOrStore(xt, d); ok {
@@ -217,15 +216,16 @@
//
// This is exported for testing purposes.
func LegacyExtensionTypeOf(xd pref.ExtensionDescriptor, t reflect.Type) pref.ExtensionType {
- return &legacyExtensionType{
- ExtensionDescriptor: xd,
- typ: t,
- conv: NewConverter(t, xd),
+ xt := &legacyExtensionType{
+ typ: t,
+ conv: NewConverter(t, xd),
}
+ xt.desc = &extDesc{xd, xt}
+ return xt
}
type legacyExtensionType struct {
- pref.ExtensionDescriptor
+ desc pref.ExtensionTypeDescriptor
typ reflect.Type
conv Converter
}
@@ -239,5 +239,12 @@
func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
return x.conv.GoValueOf(v).Interface()
}
-func (x *legacyExtensionType) Descriptor() pref.ExtensionDescriptor { return x.ExtensionDescriptor }
-func (x *legacyExtensionType) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, x) }
+func (x *legacyExtensionType) Descriptor() pref.ExtensionTypeDescriptor { return x.desc }
+
+type extDesc struct {
+ pref.ExtensionDescriptor
+ xt *legacyExtensionType
+}
+
+func (t *extDesc) Type() pref.ExtensionType { return t.xt }
+func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor }
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index 70c5603..89cd0bc 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -52,7 +52,7 @@
func init() {
mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
- preg.GlobalFiles.Register(mt.ParentFile())
+ preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
preg.GlobalTypes.Register(mt)
}
@@ -357,19 +357,21 @@
}
for i, xt := range extensionTypes {
var got interface{}
- if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) {
- got = xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ if !(xd.IsList() || xd.IsMap() || xd.Message() != nil) {
+ got = xt.InterfaceOf(m.Get(xd))
}
want := defaultValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// All fields should be unpopulated.
for _, xt := range extensionTypes {
- if m.Has(xt) {
- t.Errorf("Message.Has(%d) = true, want false", xt.Number())
+ xd := xt.Descriptor()
+ if m.Has(xd) {
+ t.Errorf("Message.Has(%d) = true, want false", xd.Number())
}
}
@@ -401,11 +403,11 @@
19: &[]*EnumMessages{m2b},
}
for i, xt := range extensionTypes {
- m.Set(xt, xt.ValueOf(setValues[i]))
+ m.Set(xt.Descriptor(), xt.ValueOf(setValues[i]))
}
for i, xt := range extensionTypes[len(extensionTypes)/2:] {
v := extensionTypes[i].ValueOf(setValues[i])
- m.Get(xt).List().Append(v)
+ m.Get(xt.Descriptor()).List().Append(v)
}
// Get the values and check for equality.
@@ -432,24 +434,25 @@
19: &[]*EnumMessages{m2b, m2a},
}
for i, xt := range extensionTypes {
- got := xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ got := xt.InterfaceOf(m.Get(xd))
want := getValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// Clear all singular fields and truncate all repeated fields.
for _, xt := range extensionTypes[:len(extensionTypes)/2] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Get(xt).List().Truncate(0)
+ m.Get(xt.Descriptor()).List().Truncate(0)
}
// Clear all repeated fields.
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
}
@@ -491,8 +494,6 @@
switch name {
case "ParentFile", "Parent":
// Ignore parents to avoid recursive cycle.
- case "New", "Zero":
- // Ignore constructors.
case "Options":
// Ignore descriptor options since protos are not cmperable.
case "ContainingOneof", "ContainingMessage", "Enum", "Message":
@@ -504,6 +505,8 @@
if !v.IsNil() {
out[name] = v.Interface().(pref.Descriptor).FullName()
}
+ case "Type":
+ // Ignore ExtensionTypeDescriptor.Type method to avoid cycle.
default:
out[name] = m.Call(nil)[0].Interface()
}
@@ -511,6 +514,12 @@
}
return out
}),
+ cmp.Transformer("", func(xt pref.ExtensionType) map[string]interface{} {
+ return map[string]interface{}{
+ "Descriptor": xt.Descriptor(),
+ "GoType": xt.GoType(),
+ }
+ }),
cmp.Transformer("", func(v pref.Value) interface{} {
return v.Interface()
}),
@@ -605,23 +614,23 @@
var (
wantMTA = messageATypes[0]
- wantMDA = messageATypes[0].Fields().ByNumber(1).Message()
+ wantMDA = messageATypes[0].Descriptor().Fields().ByNumber(1).Message()
wantMTB = messageBTypes[0]
- wantMDB = messageBTypes[0].Fields().ByNumber(2).Message()
- wantED = messageATypes[0].Fields().ByNumber(3).Enum()
+ wantMDB = messageBTypes[0].Descriptor().Fields().ByNumber(2).Message()
+ wantED = messageATypes[0].Descriptor().Fields().ByNumber(3).Enum()
)
for _, gotMT := range messageATypes[1:] {
if gotMT != wantMTA {
t.Error("MessageType(MessageA) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}
@@ -629,13 +638,13 @@
if gotMT != wantMTB {
t.Error("MessageType(MessageB) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 305e17d..6100663 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -222,8 +222,9 @@
// any discrepancies.
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
mi.fields = map[pref.FieldNumber]*fieldInfo{}
- for i := 0; i < mi.PBType.Fields().Len(); i++ {
- fd := mi.PBType.Fields().Get(i)
+ md := mi.PBType.Descriptor()
+ for i := 0; i < md.Fields().Len(); i++ {
+ fd := md.Fields().Get(i)
fs := si.fieldsByNumber[fd.Number()]
var fi fieldInfo
switch {
@@ -244,8 +245,8 @@
}
mi.oneofs = map[pref.Name]*oneofInfo{}
- for i := 0; i < mi.PBType.Oneofs().Len(); i++ {
- od := mi.PBType.Oneofs().Get(i)
+ for i := 0; i < md.Oneofs().Len(); i++ {
+ od := md.Oneofs().Get(i)
mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType)
}
}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 699ac2c..b0f1778 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -121,7 +121,7 @@
if m != nil {
for _, x := range *m {
xt := x.GetType()
- if !f(xt, xt.ValueOf(x.GetValue())) {
+ if !f(xt.Descriptor(), xt.ValueOf(x.GetValue())) {
return
}
}
@@ -129,16 +129,17 @@
}
func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
if m != nil {
- _, ok = (*m)[int32(xt.Number())]
+ _, ok = (*m)[int32(xt.Descriptor().Number())]
}
return ok
}
func (m *extensionMap) Clear(xt pref.ExtensionType) {
- delete(*m, int32(xt.Number()))
+ delete(*m, int32(xt.Descriptor().Number()))
}
func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
+ xd := xt.Descriptor()
if m != nil {
- if x, ok := (*m)[int32(xt.Number())]; ok {
+ if x, ok := (*m)[int32(xd.Number())]; ok {
return xt.ValueOf(x.GetValue())
}
}
@@ -151,13 +152,14 @@
var x ExtensionField
x.SetType(xt)
x.SetEagerValue(xt.InterfaceOf(v))
- (*m)[int32(xt.Number())] = x
+ (*m)[int32(xt.Descriptor().Number())] = x
}
func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
- if !isComposite(xt) {
+ xd := xt.Descriptor()
+ if !isComposite(xd) {
panic("invalid Mutable on field with non-composite type")
}
- if x, ok := (*m)[int32(xt.Number())]; ok {
+ if x, ok := (*m)[int32(xd.Number())]; ok {
return xt.ValueOf(x.GetValue())
}
v := xt.New()
@@ -179,14 +181,18 @@
return fi, nil
}
if fd.IsExtension() {
- if fd.ContainingMessage().FullName() != mi.PBType.FullName() {
+ if fd.ContainingMessage().FullName() != mi.PBType.Descriptor().FullName() {
// TODO: Should this be exact containing message descriptor match?
panic("mismatching containing message")
}
- if !mi.PBType.ExtensionRanges().Has(fd.Number()) {
+ if !mi.PBType.Descriptor().ExtensionRanges().Has(fd.Number()) {
panic("invalid extension field")
}
- return nil, fd.(pref.ExtensionType)
+ xtd, ok := fd.(pref.ExtensionTypeDescriptor)
+ if !ok {
+ panic("extension descriptor does not implement ExtensionTypeDescriptor")
+ }
+ return nil, xtd.Type()
}
panic("invalid field descriptor")
}
diff --git a/proto/decode.go b/proto/decode.go
index f147e68..e394243 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -88,7 +88,9 @@
if err != nil && err != protoregistry.NotFound {
return err
}
- fd = extType
+ if extType != nil {
+ fd = extType.Descriptor()
+ }
}
var err error
var valLen int
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 6088eb5..89fc303 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -1680,10 +1680,8 @@
v.Elem().Set(reflect.ValueOf(value))
value = v.Interface()
}
-
return func(m proto.Message) {
- xt := desc.Type
- m.ProtoReflect().Set(xt, xt.ValueOf(value))
+ proto.SetExtension(m, desc, value)
}
}
diff --git a/proto/extension.go b/proto/extension.go
new file mode 100644
index 0000000..2e1c78f
--- /dev/null
+++ b/proto/extension.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// HasExtension reports whether an extension field is populated.
+func HasExtension(m Message, ext protoreflect.ExtensionType) bool {
+ return m.ProtoReflect().Has(ext.Descriptor())
+}
+
+// ClearExtension clears an extension field such that subsequent
+// HasExtension calls return false.
+func ClearExtension(m Message, ext protoreflect.ExtensionType) {
+ m.ProtoReflect().Clear(ext.Descriptor())
+}
+
+// GetExtension retrieves the value for an extension field.
+//
+// If the field is unpopulated, it returns the default value for
+// scalars and an immutable, empty value for lists, maps, or messages.
+func GetExtension(m Message, ext protoreflect.ExtensionType) interface{} {
+ return ext.InterfaceOf(m.ProtoReflect().Get(ext.Descriptor()))
+}
+
+// SetExtension stores the value of an extension field.
+func SetExtension(m Message, ext protoreflect.ExtensionType, value interface{}) {
+ m.ProtoReflect().Set(ext.Descriptor(), ext.ValueOf(value))
+}
diff --git a/proto/extension_test.go b/proto/extension_test.go
new file mode 100644
index 0000000..ce3a142
--- /dev/null
+++ b/proto/extension_test.go
@@ -0,0 +1,70 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/proto"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/runtime/protoimpl"
+
+ legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
+ testpb "google.golang.org/protobuf/internal/testprotos/test"
+)
+
+func TestExtensionFuncs(t *testing.T) {
+ for _, test := range []struct {
+ message proto.Message
+ ext pref.ExtensionType
+ wantDefault interface{}
+ value interface{}
+ }{
+ {
+ message: &testpb.TestAllExtensions{},
+ ext: testpb.E_OptionalInt32Extension,
+ wantDefault: int32(0),
+ value: int32(1),
+ },
+ {
+ message: &testpb.TestAllExtensions{},
+ ext: testpb.E_RepeatedStringExtension,
+ // TODO: Represent repeated extension fields as []T.
+ // https://github.com/golang/protobuf/issues/901
+ wantDefault: (*[]string)(nil),
+ value: &[]string{"a", "b", "c"},
+ },
+ {
+ message: protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(),
+ ext: legacy1pb.E_Message_ExtensionOptionalBool,
+ wantDefault: false,
+ value: true,
+ },
+ } {
+ desc := fmt.Sprintf("Extension %v, value %v", test.ext.Descriptor().FullName(), test.value)
+ if proto.HasExtension(test.message, test.ext) {
+ t.Errorf("%v:\nbefore setting extension HasExtension(...) = true, want false", desc)
+ }
+ got := proto.GetExtension(test.message, test.ext)
+ if d := cmp.Diff(test.wantDefault, got); d != "" {
+ t.Errorf("%v:\nbefore setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
+ }
+ proto.SetExtension(test.message, test.ext, test.value)
+ if !proto.HasExtension(test.message, test.ext) {
+ t.Errorf("%v:\nafter setting extension HasExtension(...) = false, want true", desc)
+ }
+ got = proto.GetExtension(test.message, test.ext)
+ if d := cmp.Diff(test.value, got); d != "" {
+ t.Errorf("%v:\nafter setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
+ }
+ proto.ClearExtension(test.message, test.ext)
+ if proto.HasExtension(test.message, test.ext) {
+ t.Errorf("%v:\nafter clearing extension HasExtension(...) = true, want false", desc)
+ }
+
+ }
+}
diff --git a/proto/merge_test.go b/proto/merge_test.go
index 4fa161e..a0ec571 100644
--- a/proto/merge_test.go
+++ b/proto/merge_test.go
@@ -274,65 +274,41 @@
desc: "merge extension fields",
dst: func() proto.Message {
m := new(testpb.TestAllExtensions)
- m.ProtoReflect().Set(
- testpb.E_OptionalInt32Extension.Type,
- testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
- )
- m.ProtoReflect().Set(
- testpb.E_OptionalNestedMessageExtension.Type,
- testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+ proto.SetExtension(m, testpb.E_OptionalInt32Extension.Type, int32(32))
+ proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type,
+ &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(50),
- }),
+ },
)
- m.ProtoReflect().Set(
- testpb.E_RepeatedFixed32Extension.Type,
- testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}),
- )
+ proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{1, 2, 3})
return m
}(),
src: func() proto.Message {
m := new(testpb.TestAllExtensions)
- m.ProtoReflect().Set(
- testpb.E_OptionalInt64Extension.Type,
- testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
- )
- m.ProtoReflect().Set(
- testpb.E_OptionalNestedMessageExtension.Type,
- testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+ proto.SetExtension(m, testpb.E_OptionalInt64Extension.Type, int64(64))
+ proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type,
+ &testpb.TestAllTypes_NestedMessage{
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
- }),
+ },
)
- m.ProtoReflect().Set(
- testpb.E_RepeatedFixed32Extension.Type,
- testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}),
- )
+ proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{4, 5, 6})
return m
}(),
want: func() proto.Message {
m := new(testpb.TestAllExtensions)
- m.ProtoReflect().Set(
- testpb.E_OptionalInt32Extension.Type,
- testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
- )
- m.ProtoReflect().Set(
- testpb.E_OptionalInt64Extension.Type,
- testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
- )
- m.ProtoReflect().Set(
- testpb.E_OptionalNestedMessageExtension.Type,
- testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+ proto.SetExtension(m, testpb.E_OptionalInt32Extension.Type, int32(32))
+ proto.SetExtension(m, testpb.E_OptionalInt64Extension.Type, int64(64))
+ proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type,
+ &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(50),
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
- }),
+ },
)
- m.ProtoReflect().Set(
- testpb.E_RepeatedFixed32Extension.Type,
- testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}),
- )
+ proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{1, 2, 3, 4, 5, 6})
return m
}(),
}, {
diff --git a/proto/messageset.go b/proto/messageset.go
index 1c6ac29..e5d4bd5 100644
--- a/proto/messageset.go
+++ b/proto/messageset.go
@@ -71,14 +71,15 @@
if !md.ExtensionRanges().Has(num) {
return errUnknown
}
- fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+ xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
if err == protoregistry.NotFound {
return errUnknown
}
if err != nil {
return err
}
- if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil {
+ xd := xt.Descriptor()
+ if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil {
// Contents cannot be unmarshaled.
return err
}
diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go
index afd4cff..92b5750 100644
--- a/reflect/protoreflect/type.go
+++ b/reflect/protoreflect/type.go
@@ -229,8 +229,6 @@
// MessageType encapsulates a MessageDescriptor with a concrete Go implementation.
type MessageType interface {
- MessageDescriptor
-
// New returns a newly allocated empty message.
New() Message
@@ -401,6 +399,18 @@
// ExtensionDescriptor is an alias of FieldDescriptor for documentation.
type ExtensionDescriptor = FieldDescriptor
+// ExtensionTypeDescriptor is an ExtensionDescriptor with an associated ExtensionType.
+type ExtensionTypeDescriptor interface {
+ ExtensionDescriptor
+
+ // Type returns the associated ExtensionType.
+ Type() ExtensionType
+
+ // Descriptor returns the plain ExtensionDescriptor without the
+ // associated ExtensionType.
+ Descriptor() ExtensionDescriptor
+}
+
// ExtensionDescriptors is a list of field declarations.
type ExtensionDescriptors interface {
// Len reports the number of fields.
@@ -436,8 +446,6 @@
// Field "bar_field" is an extension of FooMessage, but its full name is
// "example.BarMessage.bar_field" instead of "example.FooMessage.bar_field".
type ExtensionType interface {
- ExtensionDescriptor
-
// New returns a new value for the field.
// For scalars, this returns the default value in native Go form.
New() Value
@@ -454,7 +462,7 @@
GoType() reflect.Type
// Descriptor returns the extension descriptor.
- Descriptor() ExtensionDescriptor
+ Descriptor() ExtensionTypeDescriptor
// TODO: What to do with nil?
// Should ValueOf(nil) return Value{}?
@@ -500,8 +508,6 @@
// EnumType encapsulates an EnumDescriptor with a concrete Go implementation.
type EnumType interface {
- EnumDescriptor
-
// New returns an instance of this enum type with its value set to n.
New(n EnumNumber) Enum
diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go
index 3c9229b..ec10099 100644
--- a/reflect/protoreflect/value.go
+++ b/reflect/protoreflect/value.go
@@ -30,7 +30,7 @@
// Accessor/mutators for individual fields are keyed by FieldDescriptor.
// For non-extension fields, the descriptor must exactly match the
// field known by the parent message.
-// For extension fields, the descriptor must implement ExtensionType,
+// For extension fields, the descriptor must implement ExtensionTypeDescriptor,
// extend the parent message (i.e., have the same message FullName), and
// be within the parent's extension range.
//
diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go
index 0b60c64..22b5d8c 100644
--- a/reflect/protoregistry/registry.go
+++ b/reflect/protoregistry/registry.go
@@ -317,7 +317,6 @@
// Type is an interface satisfied by protoreflect.EnumType,
// protoreflect.MessageType, or protoreflect.ExtensionType.
type Type interface {
- protoreflect.Descriptor
GoType() reflect.Type
}
@@ -428,21 +427,22 @@
switch typ.(type) {
case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
// Check for conflicts in typesByName.
- var name protoreflect.FullName
+ var desc protoreflect.Descriptor
switch t := typ.(type) {
case protoreflect.EnumType:
- name = t.FullName()
+ desc = t.Descriptor()
case protoreflect.MessageType:
- name = t.FullName()
+ desc = t.Descriptor()
case protoreflect.ExtensionType:
- name = t.FullName()
+ desc = t.Descriptor()
default:
panic(fmt.Sprintf("invalid type: %T", t))
}
+ name := desc.FullName()
if prev := r.typesByName[name]; prev != nil {
err := errors.New("%v %v is already registered", typeName(typ), name)
err = amendErrorWithCaller(err, prev, typ)
- if r == GlobalTypes && ignoreConflict(typ, err) {
+ if r == GlobalTypes && ignoreConflict(desc, err) {
err = nil
}
if firstErr == nil {
@@ -453,12 +453,13 @@
// Check for conflicts in extensionsByMessage.
if xt, _ := typ.(protoreflect.ExtensionType); xt != nil {
- field := xt.Number()
- message := xt.ContainingMessage().FullName()
+ xd := xt.Descriptor()
+ field := xd.Number()
+ message := xd.ContainingMessage().FullName()
if prev := r.extensionsByMessage[message][field]; prev != nil {
err := errors.New("extension number %d is already registered on message %v", field, message)
err = amendErrorWithCaller(err, prev, typ)
- if r == GlobalTypes && ignoreConflict(typ, err) {
+ if r == GlobalTypes && ignoreConflict(xd, err) {
err = nil
}
if firstErr == nil {
diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go
index 63dd1ba..1572b63 100644
--- a/reflect/protoregistry/registry_test.go
+++ b/reflect/protoregistry/registry_test.go
@@ -536,11 +536,11 @@
fullName := func(t preg.Type) pref.FullName {
switch t := t.(type) {
case pref.EnumType:
- return t.FullName()
+ return t.Descriptor().FullName()
case pref.MessageType:
- return t.FullName()
+ return t.Descriptor().FullName()
case pref.ExtensionType:
- return t.FullName()
+ return t.Descriptor().FullName()
default:
panic("invalid type")
}
diff --git a/runtime/protoiface/legacy.go b/runtime/protoiface/legacy.go
index 4f8d71f..d7acb33 100644
--- a/runtime/protoiface/legacy.go
+++ b/runtime/protoiface/legacy.go
@@ -5,6 +5,8 @@
package protoiface
import (
+ "reflect"
+
"google.golang.org/protobuf/reflect/protoreflect"
)
@@ -64,3 +66,31 @@
// protoreflect.FileDescriptor.Path.
Filename string
}
+
+func (e ExtensionDescV1) getType() protoreflect.ExtensionType {
+ if e.Type != nil {
+ return e.Type
+ }
+ // All ExtensionDescV1 instances in generated code should have
+ // an Type field initialized at init time, so this case only
+ // occurs for non-standard generated code and hand-written
+ // ExtensionDescs.
+ panic(`proto: ExtensionDesc.Type is not set.
+
+This error probably indicates that you are trying to use a non-standard
+"github.com/golang/protobuf/proto".ExtensionDesc with the
+"google.golang.org/golang/protobuf" API. Use a protoreflect.ExtensionType
+instead.
+`)
+}
+
+func (e ExtensionDescV1) New() protoreflect.Value { return e.getType().New() }
+func (e ExtensionDescV1) Zero() protoreflect.Value { return e.getType().Zero() }
+func (e ExtensionDescV1) GoType() reflect.Type { return e.getType().GoType() }
+func (e ExtensionDescV1) Descriptor() protoreflect.ExtensionTypeDescriptor {
+ return e.getType().Descriptor()
+}
+func (e ExtensionDescV1) ValueOf(x interface{}) protoreflect.Value { return e.getType().ValueOf(x) }
+func (e ExtensionDescV1) InterfaceOf(x protoreflect.Value) interface{} {
+ return e.getType().InterfaceOf(x)
+}
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index fded0ff..9ff58f9 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -47,7 +47,7 @@
})
}
for _, xt := range opts.ExtensionTypes {
- testField(t, m1, xt)
+ testField(t, m1, xt.Descriptor())
}
for i := 0; i < md.Oneofs().Len(); i++ {
testOneof(t, m1, md.Oneofs().Get(i))
@@ -57,12 +57,12 @@
// Test round-trip marshal/unmarshal.
m2 := m.ProtoReflect().New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil)
- b, err := proto.Marshal(m2)
+ b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
}
m3 := m.ProtoReflect().New().Interface()
- if err := proto.Unmarshal(b, m3); err != nil {
+ if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
}
if !proto.Equal(m2, m3) {
diff --git a/testing/prototest/prototest_test.go b/testing/prototest/prototest_test.go
index a95ac0b..00df3a6 100644
--- a/testing/prototest/prototest_test.go
+++ b/testing/prototest/prototest_test.go
@@ -10,9 +10,12 @@
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/testing/prototest"
irregularpb "google.golang.org/protobuf/internal/testprotos/irregular"
+ legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
+ legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
_ "google.golang.org/protobuf/internal/testprotos/test/weak1"
_ "google.golang.org/protobuf/internal/testprotos/test/weak2"
@@ -26,6 +29,8 @@
(*testpb.TestRequired)(nil),
(*irregularpb.Message)(nil),
(*testpb.TestAllExtensions)(nil),
+ (*legacypb.Legacy)(nil),
+ protoimpl.X.MessageOf((*legacy1pb.Message)(nil)).Interface(),
}
if flags.Proto1Legacy {
ms = append(ms, (*testpb.TestWeak)(nil))
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 8794167..fa9155b 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -170,7 +170,7 @@
switch {
case fd.IsExtension():
// Call InterfaceOf just to let the extension typecheck the value.
- _ = fd.(pref.ExtensionType).InterfaceOf(v)
+ _ = fd.(pref.ExtensionTypeDescriptor).Type().InterfaceOf(v)
m.ext[fd.Number()] = fd
case fd.IsMap():
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
@@ -217,7 +217,7 @@
m.checkField(fd)
switch {
case fd.IsExtension():
- return fd.(pref.ExtensionType).New()
+ return fd.(pref.ExtensionTypeDescriptor).Type().New()
case fd.IsMap():
return pref.ValueOf(&dynamicMap{
desc: fd,
@@ -258,8 +258,8 @@
func (m *Message) checkField(fd pref.FieldDescriptor) {
if fd.IsExtension() && fd.ContainingMessage().FullName() == m.Descriptor().FullName() {
- if _, ok := fd.(pref.ExtensionType); !ok {
- panic(errors.New("%v: extension field descriptor does not implement ExtensionType", fd.FullName()))
+ if _, ok := fd.(pref.ExtensionTypeDescriptor); !ok {
+ panic(errors.New("%v: extension field descriptor does not implement ExtensionTypeDescriptor", fd.FullName()))
}
return
}