internal/impl: implement oneof fields

Dynamically generate functions for handling individual fields within an oneof.
This implementation uses Go reflection to interact with the currently generated
approach, which uses an interface that can only be set by a limited set of
wrapper structs.

Change-Id: Ic848df922d6547411a15c4a20bfbbcae362da5c0
Reviewed-on: https://go-review.googlesource.com/c/142895
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 13babe0..5a5027f 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -118,7 +118,7 @@
 			continue fieldLoop
 		}
 	}
-	if fn, ok := t.MethodByName("XXX_OneofFuncs"); ok {
+	if fn, ok := reflect.PtrTo(t).MethodByName("XXX_OneofFuncs"); ok {
 		vs := fn.Func.Call([]reflect.Value{reflect.New(fn.Type.In(0)).Elem()})[3]
 	oneofLoop:
 		for _, v := range vs.Interface().([]interface{}) {
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index 1525ffd..bdab90a 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -31,8 +31,74 @@
 }
 
 func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot reflect.Type) fieldInfo {
-	// TODO: support oneof fields.
-	panic(fmt.Sprintf("invalid field: %v", fd))
+	ft := fs.Type
+	if ft.Kind() != reflect.Interface {
+		panic(fmt.Sprintf("invalid type: got %v, want interface kind", ft))
+	}
+	if ot.Kind() != reflect.Struct {
+		panic(fmt.Sprintf("invalid type: got %v, want struct kind", ot))
+	}
+	if !reflect.PtrTo(ot).Implements(ft) {
+		panic(fmt.Sprintf("invalid type: %v does not implement %v", ot, ft))
+	}
+	conv := matchGoTypePBKind(ot.Field(0).Type, fd.Kind())
+	fieldOffset := offsetOf(fs)
+	// TODO: Implement unsafe fast path?
+	return fieldInfo{
+		// NOTE: The logic below intentionally assumes that oneof fields are
+		// well-formatted. That is, the oneof interface never contains a
+		// typed nil pointer to one of the wrapper structs.
+
+		has: func(p pointer) bool {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				return false
+			}
+			return true
+		},
+		get: func(p pointer) pref.Value {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				if fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind {
+					// Return a typed nil pointer of the message type to be
+					// consistent with the behavior of generated getters.
+					rv = reflect.Zero(ot.Field(0).Type)
+					return conv.toPB(rv)
+				}
+				return fd.Default()
+			}
+			rv = rv.Elem().Elem().Field(0)
+			return conv.toPB(rv)
+		},
+		set: func(p pointer, v pref.Value) {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				rv.Set(reflect.New(ot))
+			}
+			rv = rv.Elem().Elem().Field(0)
+			rv.Set(conv.toGo(v))
+		},
+		clear: func(p pointer) {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				return
+			}
+			rv.Set(reflect.Zero(rv.Type()))
+		},
+		mutable: func(p pointer) pref.Mutable {
+			// Mutable is only valid for messages and panics for other kinds.
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				rv.Set(reflect.New(ot))
+			}
+			rv = rv.Elem().Elem().Field(0)
+			if rv.IsNil() {
+				pv := pref.ValueOf(conv.newMessage())
+				rv.Set(conv.toGo(pv))
+			}
+			return rv.Interface().(pref.Message)
+		},
+	}
 }
 
 func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index 79bcbd5..2fe53ed 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -13,6 +13,7 @@
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 
+	"github.com/golang/protobuf/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
@@ -614,6 +615,151 @@
 	})
 }
 
+type (
+	OneofScalars struct {
+		Union isOneofScalars_Union `protobuf_oneof:"union"`
+	}
+	isOneofScalars_Union interface {
+		isOneofScalars_Union()
+	}
+
+	OneofScalars_Bool struct {
+		Bool bool `protobuf:"1"`
+	}
+	OneofScalars_Int32 struct {
+		Int32 MyInt32 `protobuf:"2"`
+	}
+	OneofScalars_Int64 struct {
+		Int64 int64 `protobuf:"3"`
+	}
+	OneofScalars_Uint32 struct {
+		Uint32 MyUint32 `protobuf:"4"`
+	}
+	OneofScalars_Uint64 struct {
+		Uint64 uint64 `protobuf:"5"`
+	}
+	OneofScalars_Float32 struct {
+		Float32 MyFloat32 `protobuf:"6"`
+	}
+	OneofScalars_Float64 struct {
+		Float64 float64 `protobuf:"7"`
+	}
+	OneofScalars_String struct {
+		String string `protobuf:"8"`
+	}
+	OneofScalars_StringA struct {
+		StringA []byte `protobuf:"9"`
+	}
+	OneofScalars_StringB struct {
+		StringB MyString `protobuf:"10"`
+	}
+	OneofScalars_Bytes struct {
+		Bytes []byte `protobuf:"11"`
+	}
+	OneofScalars_BytesA struct {
+		BytesA string `protobuf:"12"`
+	}
+	OneofScalars_BytesB struct {
+		BytesB MyBytes `protobuf:"13"`
+	}
+)
+
+func (*OneofScalars) XXX_OneofFuncs() (func(proto.Message, *proto.Buffer) error, func(proto.Message, int, int, *proto.Buffer) (bool, error), func(proto.Message) int, []interface{}) {
+	return nil, nil, nil, []interface{}{
+		(*OneofScalars_Bool)(nil),
+		(*OneofScalars_Int32)(nil),
+		(*OneofScalars_Int64)(nil),
+		(*OneofScalars_Uint32)(nil),
+		(*OneofScalars_Uint64)(nil),
+		(*OneofScalars_Float32)(nil),
+		(*OneofScalars_Float64)(nil),
+		(*OneofScalars_String)(nil),
+		(*OneofScalars_StringA)(nil),
+		(*OneofScalars_StringB)(nil),
+		(*OneofScalars_Bytes)(nil),
+		(*OneofScalars_BytesA)(nil),
+		(*OneofScalars_BytesB)(nil),
+	}
+}
+
+func (*OneofScalars_Bool) isOneofScalars_Union()    {}
+func (*OneofScalars_Int32) isOneofScalars_Union()   {}
+func (*OneofScalars_Int64) isOneofScalars_Union()   {}
+func (*OneofScalars_Uint32) isOneofScalars_Union()  {}
+func (*OneofScalars_Uint64) isOneofScalars_Union()  {}
+func (*OneofScalars_Float32) isOneofScalars_Union() {}
+func (*OneofScalars_Float64) isOneofScalars_Union() {}
+func (*OneofScalars_String) isOneofScalars_Union()  {}
+func (*OneofScalars_StringA) isOneofScalars_Union() {}
+func (*OneofScalars_StringB) isOneofScalars_Union() {}
+func (*OneofScalars_Bytes) isOneofScalars_Union()   {}
+func (*OneofScalars_BytesA) isOneofScalars_Union()  {}
+func (*OneofScalars_BytesB) isOneofScalars_Union()  {}
+
+func TestOneofs(t *testing.T) {
+	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
+		Syntax:   pref.Proto2,
+		FullName: "ScalarProto2",
+		Fields: []ptype.Field{
+			{Name: "f1", Number: 1, Cardinality: pref.Optional, Kind: pref.BoolKind, Default: V(bool(true)), OneofName: "union"},
+			{Name: "f2", Number: 2, Cardinality: pref.Optional, Kind: pref.Int32Kind, Default: V(int32(2)), OneofName: "union"},
+			{Name: "f3", Number: 3, Cardinality: pref.Optional, Kind: pref.Int64Kind, Default: V(int64(3)), OneofName: "union"},
+			{Name: "f4", Number: 4, Cardinality: pref.Optional, Kind: pref.Uint32Kind, Default: V(uint32(4)), OneofName: "union"},
+			{Name: "f5", Number: 5, Cardinality: pref.Optional, Kind: pref.Uint64Kind, Default: V(uint64(5)), OneofName: "union"},
+			{Name: "f6", Number: 6, Cardinality: pref.Optional, Kind: pref.FloatKind, Default: V(float32(6)), OneofName: "union"},
+			{Name: "f7", Number: 7, Cardinality: pref.Optional, Kind: pref.DoubleKind, Default: V(float64(7)), OneofName: "union"},
+			{Name: "f8", Number: 8, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("8")), OneofName: "union"},
+			{Name: "f9", Number: 9, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("9")), OneofName: "union"},
+			{Name: "f10", Number: 10, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("10")), OneofName: "union"},
+			{Name: "f11", Number: 11, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("11")), OneofName: "union"},
+			{Name: "f12", Number: 12, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("12")), OneofName: "union"},
+			{Name: "f13", Number: 13, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("13")), OneofName: "union"},
+		},
+		Oneofs: []ptype.Oneof{{Name: "union"}},
+	})}
+
+	empty := mi.MessageOf(&OneofScalars{})
+	want1 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bool{true}})
+	want2 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int32{20}})
+	want3 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int64{30}})
+	want4 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint32{40}})
+	want5 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint64{50}})
+	want6 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float32{60}})
+	want7 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float64{70}})
+	want8 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_String{string("80")}})
+	want9 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringA{[]byte("90")}})
+	want10 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringB{MyString("100")}})
+	want11 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bytes{[]byte("110")}})
+	want12 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesA{string("120")}})
+	want13 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}})
+
+	testMessage(t, nil, mi.MessageOf(&OneofScalars{}), messageOps{
+		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
+		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))},
+
+		setFields{1: V(bool(true))}, hasFields{1: true}, equalMessage(want1),
+		setFields{2: V(int32(20))}, hasFields{2: true}, equalMessage(want2),
+		setFields{3: V(int64(30))}, hasFields{3: true}, equalMessage(want3),
+		setFields{4: V(uint32(40))}, hasFields{4: true}, equalMessage(want4),
+		setFields{5: V(uint64(50))}, hasFields{5: true}, equalMessage(want5),
+		setFields{6: V(float32(60))}, hasFields{6: true}, equalMessage(want6),
+		setFields{7: V(float64(70))}, hasFields{7: true}, equalMessage(want7),
+		setFields{8: V(string("80"))}, hasFields{8: true}, equalMessage(want8),
+		setFields{9: V(string("90"))}, hasFields{9: true}, equalMessage(want9),
+		setFields{10: V(string("100"))}, hasFields{10: true}, equalMessage(want10),
+		setFields{11: V([]byte("110"))}, hasFields{11: true}, equalMessage(want11),
+		setFields{12: V([]byte("120"))}, hasFields{12: true}, equalMessage(want12),
+		setFields{13: V([]byte("130"))}, hasFields{13: true}, equalMessage(want13),
+
+		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: true},
+		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("130"))},
+		clearFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true},
+		equalMessage(want13),
+		clearFields{13: true},
+		equalMessage(empty),
+	})
+}
+
 // TODO: Need to test singular and repeated messages
 
 var cmpOpts = cmp.Options{