reflect/protoreflect: add KnownFields.WhichOneof

Add a method that provides efficiently querying for which member field
in a oneof is actually set. This is useful when dealing with oneofs
with many member fields.

Change-Id: I918b566c432f8bdd24dcecbb5501d231ffefef29
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170580
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
index 424ee3f..ba9be0d 100644
--- a/internal/impl/legacy_extension.go
+++ b/internal/impl/legacy_extension.go
@@ -119,6 +119,10 @@
 	p.x.Set(n, x)
 }
 
+func (p legacyExtensionFields) WhichOneof(pref.Name) pref.FieldNumber {
+	return 0
+}
+
 func (p legacyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
 	p.x.Range(func(n pref.FieldNumber, x ExtensionFieldV1) bool {
 		if p.Has(n) {
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 0b2c3fc..c719e47 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -33,6 +33,7 @@
 	// TODO: Split fields into dense and sparse maps similar to the current
 	// table-driven implementation in v1?
 	fields map[pref.FieldNumber]*fieldInfo
+	oneofs map[pref.Name]*oneofInfo
 
 	unknownFields   func(*messageDataType) pref.UnknownFields
 	extensionFields func(*messageDataType) pref.KnownFields
@@ -59,27 +60,30 @@
 // any discrepancies.
 func (mi *MessageType) makeKnownFieldsFunc(t reflect.Type) {
 	// Generate a mapping of field numbers and names to Go struct field or type.
-	fields := map[pref.FieldNumber]reflect.StructField{}
-	oneofs := map[pref.Name]reflect.StructField{}
-	oneofFields := map[pref.FieldNumber]reflect.Type{}
-	special := map[string]reflect.StructField{}
+	var (
+		fieldsByNumber        = map[pref.FieldNumber]reflect.StructField{}
+		oneofsByName          = map[pref.Name]reflect.StructField{}
+		oneofWrappersByType   = map[reflect.Type]pref.FieldNumber{}
+		oneofWrappersByNumber = map[pref.FieldNumber]reflect.Type{}
+		specialByName         = map[string]reflect.StructField{}
+	)
 fieldLoop:
 	for i := 0; i < t.NumField(); i++ {
 		f := t.Field(i)
 		for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
 			if len(s) > 0 && strings.Trim(s, "0123456789") == "" {
 				n, _ := strconv.ParseUint(s, 10, 64)
-				fields[pref.FieldNumber(n)] = f
+				fieldsByNumber[pref.FieldNumber(n)] = f
 				continue fieldLoop
 			}
 		}
 		if s := f.Tag.Get("protobuf_oneof"); len(s) > 0 {
-			oneofs[pref.Name(s)] = f
+			oneofsByName[pref.Name(s)] = f
 			continue fieldLoop
 		}
 		switch f.Name {
 		case "XXX_weak", "XXX_unrecognized", "XXX_sizecache", "XXX_extensions", "XXX_InternalExtensions":
-			special[f.Name] = f
+			specialByName[f.Name] = f
 			continue fieldLoop
 		}
 	}
@@ -96,7 +100,8 @@
 		for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
 			if len(s) > 0 && strings.Trim(s, "0123456789") == "" {
 				n, _ := strconv.ParseUint(s, 10, 64)
-				oneofFields[pref.FieldNumber(n)] = tf
+				oneofWrappersByType[tf] = pref.FieldNumber(n)
+				oneofWrappersByNumber[pref.FieldNumber(n)] = tf
 				break
 			}
 		}
@@ -105,13 +110,13 @@
 	mi.fields = map[pref.FieldNumber]*fieldInfo{}
 	for i := 0; i < mi.PBType.Fields().Len(); i++ {
 		fd := mi.PBType.Fields().Get(i)
-		fs := fields[fd.Number()]
+		fs := fieldsByNumber[fd.Number()]
 		var fi fieldInfo
 		switch {
 		case fd.IsWeak():
-			fi = fieldInfoForWeak(fd, special["XXX_weak"])
+			fi = fieldInfoForWeak(fd, specialByName["XXX_weak"])
 		case fd.OneofType() != nil:
-			fi = fieldInfoForOneof(fd, oneofs[fd.OneofType().Name()], oneofFields[fd.Number()])
+			fi = fieldInfoForOneof(fd, oneofsByName[fd.OneofType().Name()], oneofWrappersByNumber[fd.Number()])
 		case fd.IsMap():
 			fi = fieldInfoForMap(fd, fs)
 		case fd.Cardinality() == pref.Repeated:
@@ -123,6 +128,12 @@
 		}
 		mi.fields[fd.Number()] = &fi
 	}
+
+	mi.oneofs = map[pref.Name]*oneofInfo{}
+	for i := 0; i < mi.PBType.Oneofs().Len(); i++ {
+		od := mi.PBType.Oneofs().Get(i)
+		mi.oneofs[od.Name()] = makeOneofInfo(od, oneofsByName[od.Name()], oneofWrappersByType)
+	}
 }
 
 func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) {
@@ -268,6 +279,12 @@
 		return
 	}
 }
+func (fs *knownFields) WhichOneof(s pref.Name) pref.FieldNumber {
+	if oi := fs.mi.oneofs[s]; oi != nil {
+		return oi.which(fs.p)
+	}
+	return 0
+}
 func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
 	for n, fi := range fs.mi.fields {
 		if fi.has(fs.p) {
@@ -309,6 +326,7 @@
 func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value               { return pref.Value{} }
 func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value)              { panic("extensions not supported") }
 func (emptyExtensionFields) Clear(pref.FieldNumber)                        { return } // noop
+func (emptyExtensionFields) WhichOneof(pref.Name) pref.FieldNumber         { return 0 }
 func (emptyExtensionFields) Range(func(pref.FieldNumber, pref.Value) bool) { return }
 func (emptyExtensionFields) NewMessage(pref.FieldNumber) pref.Message {
 	panic("extensions not supported")
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index 0ef8b12..581481c 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -290,3 +290,23 @@
 	}
 	return pv
 }
+
+type oneofInfo struct {
+	which func(pointer) pref.FieldNumber
+}
+
+func makeOneofInfo(od pref.OneofDescriptor, fs reflect.StructField, wrappersByType map[reflect.Type]pref.FieldNumber) *oneofInfo {
+	fieldOffset := offsetOf(fs)
+	return &oneofInfo{
+		which: func(p pointer) pref.FieldNumber {
+			if p.IsNil() {
+				return 0
+			}
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			if rv.IsNil() {
+				return 0
+			}
+			return wrappersByType[rv.Elem().Type().Elem()]
+		},
+	}
+}
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index 71bcb70..ef71ca1 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -52,6 +52,8 @@
 	setFields map[pref.FieldNumber]pref.Value
 	// clear specific fields in the message
 	clearFields []pref.FieldNumber
+	// check for the presence of specific oneof member fields.
+	whichOneofs map[pref.Name]pref.FieldNumber
 	// apply messageOps on each specified message field
 	messageFields map[pref.FieldNumber]messageOps
 	// apply listOps on each specified list field
@@ -67,6 +69,7 @@
 func (getFields) isMessageOp()     {}
 func (setFields) isMessageOp()     {}
 func (clearFields) isMessageOp()   {}
+func (whichOneofs) isMessageOp()   {}
 func (messageFields) isMessageOp() {}
 func (listFields) isMessageOp()    {}
 func (mapFields) isMessageOp()     {}
@@ -919,6 +922,7 @@
 	testMessage(t, nil, &OneofScalars{}, messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
 		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))},
+		whichOneofs{"union": 0, "Union": 0},
 
 		setFields{1: V(bool(true))}, hasFields{1: true}, equalMessage{want1},
 		setFields{2: V(int32(20))}, hasFields{2: true}, equalMessage{want2},
@@ -927,6 +931,10 @@
 		setFields{5: V(uint64(50))}, hasFields{5: true}, equalMessage{want5},
 		setFields{6: V(float32(60))}, hasFields{6: true}, equalMessage{want6},
 		setFields{7: V(float64(70))}, hasFields{7: true}, equalMessage{want7},
+
+		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: true, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
+		whichOneofs{"union": 7, "Union": 0},
+
 		setFields{8: V(string("80"))}, hasFields{8: true}, equalMessage{want8},
 		setFields{9: V(string("90"))}, hasFields{9: true}, equalMessage{want9},
 		setFields{10: V(string("100"))}, hasFields{10: true}, equalMessage{want10},
@@ -937,8 +945,10 @@
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: true},
 		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("130"))},
 		clearFields{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+		whichOneofs{"union": 13, "Union": 0},
 		equalMessage{want13},
 		clearFields{13},
+		whichOneofs{"union": 0, "Union": 0},
 		equalMessage{empty},
 	})
 
@@ -1240,6 +1250,15 @@
 			for _, n := range op {
 				fs.Clear(n)
 			}
+		case whichOneofs:
+			got := map[pref.Name]pref.FieldNumber{}
+			want := map[pref.Name]pref.FieldNumber(op)
+			for s := range want {
+				got[s] = fs.WhichOneof(s)
+			}
+			if diff := cmp.Diff(want, got); diff != "" {
+				t.Errorf("operation %v, KnownFields.WhichOneof mismatch (-want, +got):\n%s", p, diff)
+			}
 		case messageFields:
 			for n, tt := range op {
 				p.Push(int(n))
diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go
index d1b5ba5..6d21627 100644
--- a/reflect/protoreflect/value.go
+++ b/reflect/protoreflect/value.go
@@ -92,6 +92,10 @@
 	// a known field or extension field.
 	Clear(FieldNumber)
 
+	// WhichOneof reports which field within the named oneof is populated.
+	// It returns 0 if the oneof does not exist or no fields are populated.
+	WhichOneof(Name) FieldNumber
+
 	// Range iterates over every populated field in an undefined order,
 	// calling f for each field number and value encountered.
 	// Range calls f Len times unless f returns false, which stops iteration.