internal/order: add a package for ordered iteration over messages and maps

The order package replaces the mapsort and fieldsort packages.
It presents a common API for ordered iteration over message fields
and map fields.

It has a number of pre-defined orderings.

Change-Id: Ie6cd423da30b4757864c352cb04454f21fe07ee2
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/239837
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go
index 7d61933..58bdebe 100644
--- a/encoding/protojson/encode.go
+++ b/encoding/protojson/encode.go
@@ -7,13 +7,13 @@
 import (
 	"encoding/base64"
 	"fmt"
-	"sort"
 
 	"google.golang.org/protobuf/internal/encoding/json"
 	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/genid"
+	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
@@ -160,61 +160,71 @@
 	return nil
 }
 
+// unpopulatedFieldRanger wraps a protoreflect.Message and modifies its Range
+// method to additionally iterate over unpopulated fields.
+type unpopulatedFieldRanger struct{ pref.Message }
+
+func (m unpopulatedFieldRanger) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	fds := m.Descriptor().Fields()
+	for i := 0; i < fds.Len(); i++ {
+		fd := fds.Get(i)
+		if m.Has(fd) || fd.ContainingOneof() != nil {
+			continue // ignore populated fields and fields within a oneofs
+		}
+
+		v := m.Get(fd)
+		isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid()
+		isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil
+		if isProto2Scalar || isSingularMessage {
+			v = pref.Value{} // use invalid value to emit null
+		}
+		if !f(fd, v) {
+			return
+		}
+	}
+	m.Message.Range(f)
+}
+
 // marshalFields marshals the fields in the given protoreflect.Message.
 func (e encoder) marshalFields(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
+	if !flags.ProtoLegacy && messageset.IsMessageSet(m.Descriptor()) {
 		return errors.New("no support for proto1 MessageSets")
 	}
 
-	// Marshal out known fields.
-	fieldDescs := messageDesc.Fields()
-	for i := 0; i < fieldDescs.Len(); {
-		fd := fieldDescs.Get(i)
-		if od := fd.ContainingOneof(); od != nil {
-			fd = m.WhichOneof(od)
-			i += od.Fields().Len()
-			if fd == nil {
-				continue // unpopulated oneofs are not affected by EmitUnpopulated
-			}
-		} else {
-			i++
-		}
+	var fields order.FieldRanger = m
+	if e.opts.EmitUnpopulated {
+		fields = unpopulatedFieldRanger{m}
+	}
 
-		val := m.Get(fd)
-		if !m.Has(fd) {
-			if !e.opts.EmitUnpopulated {
-				continue
+	var err error
+	order.RangeFields(fields, order.IndexNameFieldOrder, func(fd pref.FieldDescriptor, v pref.Value) bool {
+		var name string
+		switch {
+		case fd.IsExtension():
+			if messageset.IsMessageSetExtension(fd) {
+				name = "[" + string(fd.FullName().Parent()) + "]"
+			} else {
+				name = "[" + string(fd.FullName()) + "]"
 			}
-			isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid()
-			isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil
-			if isProto2Scalar || isSingularMessage {
-				// Use invalid value to emit null.
-				val = pref.Value{}
-			}
-		}
-
-		name := fd.JSONName()
-		if e.opts.UseProtoNames {
-			name = string(fd.Name())
-			// Use type name for group field name.
+		case e.opts.UseProtoNames:
 			if fd.Kind() == pref.GroupKind {
 				name = string(fd.Message().Name())
+			} else {
+				name = string(fd.Name())
 			}
+		default:
+			name = fd.JSONName()
 		}
-		if err := e.WriteName(name); err != nil {
-			return err
-		}
-		if err := e.marshalValue(val, fd); err != nil {
-			return err
-		}
-	}
 
-	// Marshal out extensions.
-	if err := e.marshalExtensions(m); err != nil {
-		return err
-	}
-	return nil
+		if err = e.WriteName(name); err != nil {
+			return false
+		}
+		if err = e.marshalValue(v, fd); err != nil {
+			return false
+		}
+		return true
+	})
+	return err
 }
 
 // marshalValue marshals the given protoreflect.Value.
@@ -305,98 +315,20 @@
 	return nil
 }
 
-type mapEntry struct {
-	key   pref.MapKey
-	value pref.Value
-}
-
 // marshalMap marshals given protoreflect.Map.
 func (e encoder) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
 	e.StartObject()
 	defer e.EndObject()
 
-	// Get a sorted list based on keyType first.
-	entries := make([]mapEntry, 0, mmap.Len())
-	mmap.Range(func(key pref.MapKey, val pref.Value) bool {
-		entries = append(entries, mapEntry{key: key, value: val})
+	var err error
+	order.RangeEntries(mmap, order.GenericKeyOrder, func(k pref.MapKey, v pref.Value) bool {
+		if err = e.WriteName(k.String()); err != nil {
+			return false
+		}
+		if err = e.marshalSingular(v, fd.MapValue()); err != nil {
+			return false
+		}
 		return true
 	})
-	sortMap(fd.MapKey().Kind(), entries)
-
-	// Write out sorted list.
-	for _, entry := range entries {
-		if err := e.WriteName(entry.key.String()); err != nil {
-			return err
-		}
-		if err := e.marshalSingular(entry.value, fd.MapValue()); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// sortMap orders list based on value of key field for deterministic ordering.
-func sortMap(keyKind pref.Kind, values []mapEntry) {
-	sort.Slice(values, func(i, j int) bool {
-		switch keyKind {
-		case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
-			pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
-			return values[i].key.Int() < values[j].key.Int()
-
-		case pref.Uint32Kind, pref.Fixed32Kind,
-			pref.Uint64Kind, pref.Fixed64Kind:
-			return values[i].key.Uint() < values[j].key.Uint()
-		}
-		return values[i].key.String() < values[j].key.String()
-	})
-}
-
-// marshalExtensions marshals extension fields.
-func (e encoder) marshalExtensions(m pref.Message) error {
-	type entry struct {
-		key   string
-		value pref.Value
-		desc  pref.FieldDescriptor
-	}
-
-	// Get a sorted list based on field key first.
-	var entries []entry
-	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
-		if !fd.IsExtension() {
-			return true
-		}
-
-		// For MessageSet extensions, the name used is the parent message.
-		name := fd.FullName()
-		if messageset.IsMessageSetExtension(fd) {
-			name = name.Parent()
-		}
-
-		// Use [name] format for JSON field name.
-		entries = append(entries, entry{
-			key:   string(name),
-			value: v,
-			desc:  fd,
-		})
-		return true
-	})
-
-	// Sort extensions lexicographically.
-	sort.Slice(entries, func(i, j int) bool {
-		return entries[i].key < entries[j].key
-	})
-
-	// Write out sorted list.
-	for _, entry := range entries {
-		// JSON field name is the proto field name enclosed in [], similar to
-		// textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
-		// marshal out extension fields.
-		if err := e.WriteName("[" + entry.key + "]"); err != nil {
-			return err
-		}
-		if err := e.marshalValue(entry.value, entry.desc); err != nil {
-			return err
-		}
-	}
-	return nil
+	return err
 }
diff --git a/encoding/protojson/encode_test.go b/encoding/protojson/encode_test.go
index b4cf3bb..0cca937 100644
--- a/encoding/protojson/encode_test.go
+++ b/encoding/protojson/encode_test.go
@@ -1060,12 +1060,12 @@
 			return m
 		}(),
 		want: `{
-  "[pb2.MessageSetExtension]": {
-    "optString": "a messageset extension"
-  },
   "[pb2.MessageSetExtension.ext_nested]": {
     "optString": "just a regular extension"
   },
+  "[pb2.MessageSetExtension]": {
+    "optString": "a messageset extension"
+  },
   "[pb2.MessageSetExtension.not_message_set_extension]": {
     "optString": "not a messageset extension"
   }
@@ -2125,6 +2125,35 @@
   ]
 }`,
 	}, {
+		desc: "EmitUnpopulated: with populated fields",
+		mo:   protojson.MarshalOptions{EmitUnpopulated: true},
+		input: &pb2.Scalars{
+			OptInt32:    proto.Int32(0xff),
+			OptUint32:   proto.Uint32(47),
+			OptSint32:   proto.Int32(-1001),
+			OptFixed32:  proto.Uint32(32),
+			OptSfixed32: proto.Int32(-32),
+			OptFloat:    proto.Float32(1.02),
+			OptBytes:    []byte("谷歌"),
+		},
+		want: `{
+  "optBool": null,
+  "optInt32": 255,
+  "optInt64": null,
+  "optUint32": 47,
+  "optUint64": null,
+  "optSint32": -1001,
+  "optSint64": null,
+  "optFixed32": 32,
+  "optFixed64": null,
+  "optSfixed32": -32,
+  "optSfixed64": null,
+  "optFloat": 1.02,
+  "optDouble": null,
+  "optBytes": "6LC35q2M",
+  "optString": null
+}`,
+	}, {
 		desc: "UseEnumNumbers in singular field",
 		mo:   protojson.MarshalOptions{UseEnumNumbers: true},
 		input: &pb2.Enums{
diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go
index 0877d71..3171156 100644
--- a/encoding/prototext/encode.go
+++ b/encoding/prototext/encode.go
@@ -6,7 +6,6 @@
 
 import (
 	"fmt"
-	"sort"
 	"strconv"
 	"unicode/utf8"
 
@@ -16,10 +15,11 @@
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/genid"
-	"google.golang.org/protobuf/internal/mapsort"
+	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/internal/strs"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
 )
@@ -169,35 +169,30 @@
 		// If unable to expand, continue on to marshal Any as a regular message.
 	}
 
-	// Marshal known fields.
-	fieldDescs := messageDesc.Fields()
-	size := fieldDescs.Len()
-	for i := 0; i < size; {
-		fd := fieldDescs.Get(i)
-		if od := fd.ContainingOneof(); od != nil {
-			fd = m.WhichOneof(od)
-			i += od.Fields().Len()
+	// Marshal fields.
+	var err error
+	order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		var name string
+		if fd.IsExtension() {
+			if messageset.IsMessageSetExtension(fd) {
+				name = "[" + string(fd.FullName().Parent()) + "]"
+			} else {
+				name = "[" + string(fd.FullName()) + "]"
+			}
 		} else {
-			i++
+			if fd.Kind() == pref.GroupKind {
+				name = string(fd.Message().Name())
+			} else {
+				name = string(fd.Name())
+			}
 		}
 
-		if fd == nil || !m.Has(fd) {
-			continue
+		if err = e.marshalField(string(name), v, fd); err != nil {
+			return false
 		}
-
-		name := fd.Name()
-		// Use type name for group field name.
-		if fd.Kind() == pref.GroupKind {
-			name = fd.Message().Name()
-		}
-		val := m.Get(fd)
-		if err := e.marshalField(string(name), val, fd); err != nil {
-			return err
-		}
-	}
-
-	// Marshal extensions.
-	if err := e.marshalExtensions(m); err != nil {
+		return true
+	})
+	if err != nil {
 		return err
 	}
 
@@ -290,7 +285,7 @@
 // marshalMap marshals the given protoreflect.Map as multiple name-value fields.
 func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) error {
 	var err error
-	mapsort.Range(mmap, fd.MapKey().Kind(), func(key pref.MapKey, val pref.Value) bool {
+	order.RangeEntries(mmap, order.GenericKeyOrder, func(key pref.MapKey, val pref.Value) bool {
 		e.WriteName(name)
 		e.StartMessage()
 		defer e.EndMessage()
@@ -311,48 +306,6 @@
 	return err
 }
 
-// marshalExtensions marshals extension fields.
-func (e encoder) marshalExtensions(m pref.Message) error {
-	type entry struct {
-		key   string
-		value pref.Value
-		desc  pref.FieldDescriptor
-	}
-
-	// Get a sorted list based on field key first.
-	var entries []entry
-	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
-		if !fd.IsExtension() {
-			return true
-		}
-		// For MessageSet extensions, the name used is the parent message.
-		name := fd.FullName()
-		if messageset.IsMessageSetExtension(fd) {
-			name = name.Parent()
-		}
-		entries = append(entries, entry{
-			key:   string(name),
-			value: v,
-			desc:  fd,
-		})
-		return true
-	})
-	// Sort extensions lexicographically.
-	sort.Slice(entries, func(i, j int) bool {
-		return entries[i].key < entries[j].key
-	})
-
-	// Write out sorted list.
-	for _, entry := range entries {
-		// Extension field name is the proto field name enclosed in [].
-		name := "[" + entry.key + "]"
-		if err := e.marshalField(name, entry.value, entry.desc); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // marshalUnknown parses the given []byte and marshals fields out.
 // This function assumes proper encoding in the given []byte.
 func (e encoder) marshalUnknown(b []byte) {
diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go
index 4de385c..49fba14 100644
--- a/encoding/prototext/encode_test.go
+++ b/encoding/prototext/encode_test.go
@@ -1158,12 +1158,12 @@
 			})
 			return m
 		}(),
-		want: `[pb2.MessageSetExtension]: {
-  opt_string: "a messageset extension"
-}
-[pb2.MessageSetExtension.ext_nested]: {
+		want: `[pb2.MessageSetExtension.ext_nested]: {
   opt_string: "just a regular extension"
 }
+[pb2.MessageSetExtension]: {
+  opt_string: "a messageset extension"
+}
 [pb2.MessageSetExtension.not_message_set_extension]: {
   opt_string: "not a messageset extension"
 }
diff --git a/internal/fieldsort/fieldsort.go b/internal/fieldsort/fieldsort.go
deleted file mode 100644
index 517c4e2..0000000
--- a/internal/fieldsort/fieldsort.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package fieldsort defines an ordering of fields.
-//
-// The ordering defined by this package matches the historic behavior of the proto
-// package, placing extensions first and oneofs last.
-//
-// There is no guarantee about stability of the wire encoding, and users should not
-// depend on the order defined in this package as it is subject to change without
-// notice.
-package fieldsort
-
-import (
-	"google.golang.org/protobuf/reflect/protoreflect"
-)
-
-// Less returns true if field a comes before field j in ordered wire marshal output.
-func Less(a, b protoreflect.FieldDescriptor) bool {
-	ea := a.IsExtension()
-	eb := b.IsExtension()
-	oa := a.ContainingOneof()
-	ob := b.ContainingOneof()
-	switch {
-	case ea != eb:
-		return ea
-	case oa != nil && ob != nil:
-		if oa == ob {
-			return a.Number() < b.Number()
-		}
-		return oa.Index() < ob.Index()
-	case oa != nil && !oa.IsSynthetic():
-		return false
-	case ob != nil && !ob.IsSynthetic():
-		return true
-	default:
-		return a.Number() < b.Number()
-	}
-}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 0e176d5..733ad9f 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -11,7 +11,7 @@
 
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/internal/encoding/messageset"
-	"google.golang.org/protobuf/internal/fieldsort"
+	"google.golang.org/protobuf/internal/order"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
@@ -136,7 +136,7 @@
 		sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
 			fi := fields.ByNumber(mi.orderedCoderFields[i].num)
 			fj := fields.ByNumber(mi.orderedCoderFields[j].num)
-			return fieldsort.Less(fi, fj)
+			return order.LegacyFieldOrder(fi, fj)
 		})
 	}
 
diff --git a/internal/mapsort/mapsort.go b/internal/mapsort/mapsort.go
deleted file mode 100644
index a3de1cf..0000000
--- a/internal/mapsort/mapsort.go
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package mapsort provides sorted access to maps.
-package mapsort
-
-import (
-	"sort"
-
-	"google.golang.org/protobuf/reflect/protoreflect"
-)
-
-// Range iterates over every map entry in sorted key order,
-// calling f for each key and value encountered.
-func Range(mapv protoreflect.Map, keyKind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
-	var keys []protoreflect.MapKey
-	mapv.Range(func(key protoreflect.MapKey, _ protoreflect.Value) bool {
-		keys = append(keys, key)
-		return true
-	})
-	sort.Slice(keys, func(i, j int) bool {
-		switch keyKind {
-		case protoreflect.BoolKind:
-			return !keys[i].Bool() && keys[j].Bool()
-		case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
-			protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
-			return keys[i].Int() < keys[j].Int()
-		case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
-			protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
-			return keys[i].Uint() < keys[j].Uint()
-		case protoreflect.StringKind:
-			return keys[i].String() < keys[j].String()
-		default:
-			panic("invalid kind: " + keyKind.String())
-		}
-	})
-	for _, key := range keys {
-		if !f(key, mapv.Get(key)) {
-			break
-		}
-	}
-}
diff --git a/internal/mapsort/mapsort_test.go b/internal/mapsort/mapsort_test.go
deleted file mode 100644
index 6d17946..0000000
--- a/internal/mapsort/mapsort_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package mapsort_test
-
-import (
-	"strconv"
-	"testing"
-
-	"google.golang.org/protobuf/internal/mapsort"
-	pref "google.golang.org/protobuf/reflect/protoreflect"
-
-	testpb "google.golang.org/protobuf/internal/testprotos/test"
-)
-
-func TestRange(t *testing.T) {
-	m := (&testpb.TestAllTypes{
-		MapBoolBool: map[bool]bool{
-			false: false,
-			true:  true,
-		},
-		MapInt32Int32: map[int32]int32{
-			0: 0,
-			1: 1,
-			2: 2,
-		},
-		MapUint64Uint64: map[uint64]uint64{
-			0: 0,
-			1: 1,
-			2: 2,
-		},
-		MapStringString: map[string]string{
-			"0": "0",
-			"1": "1",
-			"2": "2",
-		},
-	}).ProtoReflect()
-	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
-		mapv := v.Map()
-		var got []pref.MapKey
-		mapsort.Range(mapv, fd.MapKey().Kind(), func(key pref.MapKey, _ pref.Value) bool {
-			got = append(got, key)
-			return true
-		})
-		for wanti, key := range got {
-			var goti int
-			switch x := mapv.Get(key).Interface().(type) {
-			case bool:
-				if x {
-					goti = 1
-				}
-			case int32:
-				goti = int(x)
-			case uint64:
-				goti = int(x)
-			case string:
-				goti, _ = strconv.Atoi(x)
-			default:
-				t.Fatalf("unhandled map value type %T", x)
-			}
-			if wanti != goti {
-				t.Errorf("out of order range over map field %v: %v", fd.FullName(), got)
-				break
-			}
-		}
-		return true
-	})
-}
diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go
index 9547a53..f01cf60 100644
--- a/internal/msgfmt/format.go
+++ b/internal/msgfmt/format.go
@@ -20,7 +20,7 @@
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/genid"
-	"google.golang.org/protobuf/internal/mapsort"
+	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
@@ -64,25 +64,8 @@
 		return b2
 	}
 
-	var fds []protoreflect.FieldDescriptor
-	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
-		fds = append(fds, fd)
-		return true
-	})
-	sort.Slice(fds, func(i, j int) bool {
-		fdi, fdj := fds[i], fds[j]
-		switch {
-		case !fdi.IsExtension() && !fdj.IsExtension():
-			return fdi.Index() < fdj.Index()
-		case fdi.IsExtension() && fdj.IsExtension():
-			return fdi.FullName() < fdj.FullName()
-		default:
-			return !fdi.IsExtension() && fdj.IsExtension()
-		}
-	})
-
 	b = append(b, '{')
-	for _, fd := range fds {
+	order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 		k := string(fd.Name())
 		if fd.IsExtension() {
 			k = string("[" + fd.FullName() + "]")
@@ -90,9 +73,10 @@
 
 		b = append(b, k...)
 		b = append(b, ':')
-		b = appendValue(b, m.Get(fd), fd)
+		b = appendValue(b, v, fd)
 		b = append(b, delim()...)
-	}
+		return true
+	})
 	b = appendUnknown(b, m.GetUnknown())
 	b = bytes.TrimRight(b, delim())
 	b = append(b, '}')
@@ -247,19 +231,14 @@
 }
 
 func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte {
-	var ks []protoreflect.MapKey
-	mapsort.Range(v, fd.MapKey().Kind(), func(k protoreflect.MapKey, _ protoreflect.Value) bool {
-		ks = append(ks, k)
-		return true
-	})
-
 	b = append(b, '{')
-	for _, k := range ks {
+	order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
 		b = appendValue(b, k.Value(), fd.MapKey())
 		b = append(b, ':')
-		b = appendValue(b, v.Get(k), fd.MapValue())
+		b = appendValue(b, v, fd.MapValue())
 		b = append(b, delim()...)
-	}
+		return true
+	})
 	b = bytes.TrimRight(b, delim())
 	b = append(b, '}')
 	return b
diff --git a/internal/order/order.go b/internal/order/order.go
new file mode 100644
index 0000000..2a24953
--- /dev/null
+++ b/internal/order/order.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package order
+
+import (
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// FieldOrder specifies the ordering to visit message fields.
+// It is a function that reports whether x is ordered before y.
+type FieldOrder func(x, y pref.FieldDescriptor) bool
+
+var (
+	// AnyFieldOrder specifies no specific field ordering.
+	AnyFieldOrder FieldOrder = nil
+
+	// LegacyFieldOrder sorts fields in the same ordering as emitted by
+	// wire serialization in the github.com/golang/protobuf implementation.
+	LegacyFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
+		ox, oy := x.ContainingOneof(), y.ContainingOneof()
+		inOneof := func(od pref.OneofDescriptor) bool {
+			return od != nil && !od.IsSynthetic()
+		}
+
+		// Extension fields sort before non-extension fields.
+		if x.IsExtension() != y.IsExtension() {
+			return x.IsExtension() && !y.IsExtension()
+		}
+		// Fields not within a oneof sort before those within a oneof.
+		if inOneof(ox) != inOneof(oy) {
+			return !inOneof(ox) && inOneof(oy)
+		}
+		// Fields in disjoint oneof sets are sorted by declaration index.
+		if ox != nil && oy != nil && ox != oy {
+			return ox.Index() < oy.Index()
+		}
+		// Fields sorted by field number.
+		return x.Number() < y.Number()
+	}
+
+	// NumberFieldOrder sorts fields by their field number.
+	NumberFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
+		return x.Number() < y.Number()
+	}
+
+	// IndexNameFieldOrder sorts non-extension fields before extension fields.
+	// Non-extensions are sorted according to their declaration index.
+	// Extensions are sorted according to their full name.
+	IndexNameFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
+		// Non-extension fields sort before extension fields.
+		if x.IsExtension() != y.IsExtension() {
+			return !x.IsExtension() && y.IsExtension()
+		}
+		// Extensions sorted by fullname.
+		if x.IsExtension() && y.IsExtension() {
+			return x.FullName() < y.FullName()
+		}
+		// Non-extensions sorted by declaration index.
+		return x.Index() < y.Index()
+	}
+)
+
+// KeyOrder specifies the ordering to visit map entries.
+// It is a function that reports whether x is ordered before y.
+type KeyOrder func(x, y pref.MapKey) bool
+
+var (
+	// AnyKeyOrder specifies no specific key ordering.
+	AnyKeyOrder KeyOrder = nil
+
+	// GenericKeyOrder sorts false before true, numeric keys in ascending order,
+	// and strings in lexicographical ordering according to UTF-8 codepoints.
+	GenericKeyOrder KeyOrder = func(x, y pref.MapKey) bool {
+		switch x.Interface().(type) {
+		case bool:
+			return !x.Bool() && y.Bool()
+		case int32, int64:
+			return x.Int() < y.Int()
+		case uint32, uint64:
+			return x.Uint() < y.Uint()
+		case string:
+			return x.String() < y.String()
+		default:
+			panic("invalid map key type")
+		}
+	}
+)
diff --git a/internal/order/order_test.go b/internal/order/order_test.go
new file mode 100644
index 0000000..ecf5e18
--- /dev/null
+++ b/internal/order/order_test.go
@@ -0,0 +1,175 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package order
+
+import (
+	"math/rand"
+	"sort"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"google.golang.org/protobuf/reflect/protoreflect"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+type fieldDesc struct {
+	index      int
+	name       protoreflect.FullName
+	number     protoreflect.FieldNumber
+	extension  bool
+	oneofIndex int // non-zero means within oneof; negative means synthetic
+	pref.FieldDescriptor
+}
+
+func (d fieldDesc) Index() int               { return d.index }
+func (d fieldDesc) Name() pref.Name          { return d.name.Name() }
+func (d fieldDesc) FullName() pref.FullName  { return d.name }
+func (d fieldDesc) Number() pref.FieldNumber { return d.number }
+func (d fieldDesc) IsExtension() bool        { return d.extension }
+func (d fieldDesc) ContainingOneof() pref.OneofDescriptor {
+	switch {
+	case d.oneofIndex < 0:
+		return oneofDesc{index: -d.oneofIndex, synthetic: true}
+	case d.oneofIndex > 0:
+		return oneofDesc{index: +d.oneofIndex, synthetic: false}
+	default:
+		return nil
+	}
+}
+
+type oneofDesc struct {
+	index     int
+	synthetic bool
+	pref.OneofDescriptor
+}
+
+func (d oneofDesc) Index() int        { return d.index }
+func (d oneofDesc) IsSynthetic() bool { return d.synthetic }
+
+func TestFieldOrder(t *testing.T) {
+	tests := []struct {
+		label  string
+		order  FieldOrder
+		fields []fieldDesc
+	}{{
+		label: "LegacyFieldOrder",
+		order: LegacyFieldOrder,
+		fields: []fieldDesc{
+			// Extension fields sorted first by field number.
+			{number: 2, extension: true},
+			{number: 4, extension: true},
+			{number: 100, extension: true},
+			{number: 120, extension: true},
+
+			// Non-extension fields that are not within a oneof
+			// sorted next by field number.
+			{number: 1},
+			{number: 5, oneofIndex: -9}, // synthetic oneof
+			{number: 10},
+			{number: 11, oneofIndex: -10}, // synthetic oneof
+			{number: 12},
+
+			// Non-synthetic oneofs sorted last by index.
+			{number: 13, oneofIndex: 4},
+			{number: 3, oneofIndex: 5},
+			{number: 9, oneofIndex: 5},
+			{number: 7, oneofIndex: 8},
+		},
+	}, {
+		label: "NumberFieldOrder",
+		order: NumberFieldOrder,
+		fields: []fieldDesc{
+			{number: 1, index: 5, name: "c"},
+			{number: 2, index: 2, name: "b"},
+			{number: 3, index: 3, name: "d"},
+			{number: 5, index: 1, name: "a"},
+			{number: 7, index: 7, name: "e"},
+		},
+	}, {
+		label: "IndexNameFieldOrder",
+		order: IndexNameFieldOrder,
+		fields: []fieldDesc{
+			// Non-extension fields sorted first by index.
+			{index: 0, number: 5, name: "c"},
+			{index: 2, number: 2, name: "a"},
+			{index: 4, number: 4, name: "b"},
+			{index: 7, number: 6, name: "d"},
+
+			// Extension fields sorted last by full name.
+			{index: 3, number: 1, name: "d.a", extension: true},
+			{index: 5, number: 3, name: "e", extension: true},
+			{index: 1, number: 7, name: "g", extension: true},
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.label, func(t *testing.T) {
+			want := tt.fields
+			got := append([]fieldDesc(nil), want...)
+			for i, j := range rand.Perm(len(got)) {
+				got[i], got[j] = got[j], got[i]
+			}
+			sort.Slice(got, func(i, j int) bool {
+				return tt.order(got[i], got[j])
+			})
+			if diff := cmp.Diff(want, got,
+				cmp.Comparer(func(x, y fieldDesc) bool { return x == y }),
+			); diff != "" {
+				t.Errorf("order mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestKeyOrder(t *testing.T) {
+	tests := []struct {
+		label string
+		order KeyOrder
+		keys  []interface{}
+	}{{
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{false, true},
+	}, {
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{int32(-100), int32(-99), int32(-10), int32(-9), int32(-1), int32(0), int32(+1), int32(+9), int32(+10), int32(+99), int32(+100)},
+	}, {
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{int64(-100), int64(-99), int64(-10), int64(-9), int64(-1), int64(0), int64(+1), int64(+9), int64(+10), int64(+99), int64(+100)},
+	}, {
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{uint32(0), uint32(1), uint32(9), uint32(10), uint32(99), uint32(100)},
+	}, {
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{uint64(0), uint64(1), uint64(9), uint64(10), uint64(99), uint64(100)},
+	}, {
+		label: "GenericKeyOrder",
+		order: GenericKeyOrder,
+		keys:  []interface{}{"", "a", "aa", "ab", "ba", "bb", "\u0080", "\u0080\u0081", "\u0082\u0080"},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.label, func(t *testing.T) {
+			var got, want []protoreflect.MapKey
+			for _, v := range tt.keys {
+				want = append(want, pref.ValueOf(v).MapKey())
+			}
+			got = append(got, want...)
+			for i, j := range rand.Perm(len(got)) {
+				got[i], got[j] = got[j], got[i]
+			}
+			sort.Slice(got, func(i, j int) bool {
+				return tt.order(got[i], got[j])
+			})
+			if diff := cmp.Diff(want, got, cmp.Transformer("", protoreflect.MapKey.Interface)); diff != "" {
+				t.Errorf("order mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/internal/order/range.go b/internal/order/range.go
new file mode 100644
index 0000000..c8090e0
--- /dev/null
+++ b/internal/order/range.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package order provides ordered access to messages and maps.
+package order
+
+import (
+	"sort"
+	"sync"
+
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+type messageField struct {
+	fd pref.FieldDescriptor
+	v  pref.Value
+}
+
+var messageFieldPool = sync.Pool{
+	New: func() interface{} { return new([]messageField) },
+}
+
+type (
+	// FieldRnger is an interface for visiting all fields in a message.
+	// The protoreflect.Message type implements this interface.
+	FieldRanger interface{ Range(VisitField) }
+	// VisitField is called everytime a message field is visited.
+	VisitField = func(pref.FieldDescriptor, pref.Value) bool
+)
+
+// RangeFields iterates over the fields of fs according to the specified order.
+func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) {
+	if less == nil {
+		fs.Range(fn)
+		return
+	}
+
+	// Obtain a pre-allocated scratch buffer.
+	p := messageFieldPool.Get().(*[]messageField)
+	fields := (*p)[:0]
+	defer func() {
+		if cap(fields) < 1024 {
+			*p = fields
+			messageFieldPool.Put(p)
+		}
+	}()
+
+	// Collect all fields in the message and sort them.
+	fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		fields = append(fields, messageField{fd, v})
+		return true
+	})
+	sort.Slice(fields, func(i, j int) bool {
+		return less(fields[i].fd, fields[j].fd)
+	})
+
+	// Visit the fields in the specified ordering.
+	for _, f := range fields {
+		if !fn(f.fd, f.v) {
+			return
+		}
+	}
+}
+
+type mapEntry struct {
+	k pref.MapKey
+	v pref.Value
+}
+
+var mapEntryPool = sync.Pool{
+	New: func() interface{} { return new([]mapEntry) },
+}
+
+type (
+	// EntryRanger is an interface for visiting all fields in a message.
+	// The protoreflect.Map type implements this interface.
+	EntryRanger interface{ Range(VisitEntry) }
+	// VisitEntry is called everytime a map entry is visited.
+	VisitEntry = func(pref.MapKey, pref.Value) bool
+)
+
+// RangeEntries iterates over the entries of es according to the specified order.
+func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) {
+	if less == nil {
+		es.Range(fn)
+		return
+	}
+
+	// Obtain a pre-allocated scratch buffer.
+	p := mapEntryPool.Get().(*[]mapEntry)
+	entries := (*p)[:0]
+	defer func() {
+		if cap(entries) < 1024 {
+			*p = entries
+			mapEntryPool.Put(p)
+		}
+	}()
+
+	// Collect all entries in the map and sort them.
+	es.Range(func(k pref.MapKey, v pref.Value) bool {
+		entries = append(entries, mapEntry{k, v})
+		return true
+	})
+	sort.Slice(entries, func(i, j int) bool {
+		return less(entries[i].k, entries[j].k)
+	})
+
+	// Visit the entries in the specified ordering.
+	for _, e := range entries {
+		if !fn(e.k, e.v) {
+			return
+		}
+	}
+}
diff --git a/proto/encode.go b/proto/encode.go
index 7b47a11..d18239c 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -5,12 +5,9 @@
 package proto
 
 import (
-	"sort"
-
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/internal/encoding/messageset"
-	"google.golang.org/protobuf/internal/fieldsort"
-	"google.golang.org/protobuf/internal/mapsort"
+	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoiface"
@@ -211,14 +208,15 @@
 	if messageset.IsMessageSet(m.Descriptor()) {
 		return o.marshalMessageSet(b, m)
 	}
-	// There are many choices for what order we visit fields in. The default one here
-	// is chosen for reasonable efficiency and simplicity given the protoreflect API.
-	// It is not deterministic, since Message.Range does not return fields in any
-	// defined order.
-	//
-	// When using deterministic serialization, we sort the known fields.
+	fieldOrder := order.AnyFieldOrder
+	if o.Deterministic {
+		// TODO: This should use a more natural ordering like NumberFieldOrder,
+		// but doing so breaks golden tests that make invalid assumption about
+		// output stability of this implementation.
+		fieldOrder = order.LegacyFieldOrder
+	}
 	var err error
-	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+	order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 		b, err = o.marshalField(b, fd, v)
 		return err == nil
 	})
@@ -229,27 +227,6 @@
 	return b, nil
 }
 
-// rangeFields visits fields in a defined order when deterministic serialization is enabled.
-func (o MarshalOptions) rangeFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	if !o.Deterministic {
-		m.Range(f)
-		return
-	}
-	var fds []protoreflect.FieldDescriptor
-	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
-		fds = append(fds, fd)
-		return true
-	})
-	sort.Slice(fds, func(a, b int) bool {
-		return fieldsort.Less(fds[a], fds[b])
-	})
-	for _, fd := range fds {
-		if !f(fd, m.Get(fd)) {
-			break
-		}
-	}
-}
-
 func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
 	switch {
 	case fd.IsList():
@@ -292,8 +269,12 @@
 func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
 	keyf := fd.MapKey()
 	valf := fd.MapValue()
+	keyOrder := order.AnyKeyOrder
+	if o.Deterministic {
+		keyOrder = order.GenericKeyOrder
+	}
 	var err error
-	o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool {
+	order.RangeEntries(mapv, keyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool {
 		b = protowire.AppendTag(b, fd.Number(), protowire.BytesType)
 		var pos int
 		b, pos = appendSpeculativeLength(b)
@@ -312,14 +293,6 @@
 	return b, err
 }
 
-func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
-	if !o.Deterministic {
-		mapv.Range(f)
-		return
-	}
-	mapsort.Range(mapv, kind, f)
-}
-
 // When encoding length-prefixed fields, we speculatively set aside some number of bytes
 // for the length, encode the data, and then encode the length (shifting the data if necessary
 // to make room).
diff --git a/proto/messageset.go b/proto/messageset.go
index 1d692c3..312d5d4 100644
--- a/proto/messageset.go
+++ b/proto/messageset.go
@@ -9,6 +9,7 @@
 	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
 )
@@ -28,8 +29,12 @@
 	if !flags.ProtoLegacy {
 		return b, errors.New("no support for message_set_wire_format")
 	}
+	fieldOrder := order.AnyFieldOrder
+	if o.Deterministic {
+		fieldOrder = order.NumberFieldOrder
+	}
 	var err error
-	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+	order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 		b, err = o.marshalMessageSetField(b, fd, v)
 		return err == nil
 	})