internal/impl: inline small tag decoding

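Inline varint decoding of small (1- and 2-byte) field tags in the
fast-path unmarshaler instead of calling wire.ConsumeTag. Tags longer
than two bytes fall back to wire.ConsumeVarint. The field-number check
now also rejects numbers below wire.MinValidNumber. Remove the TODOs
about specializing the 1- and 2-byte cases.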

name                             old time/op  new time/op  delta
EmptyMessage/Wire/Unmarshal      40.6ns ± 1%  40.2ns ± 1%   -1.02%  (p=0.000 n=37+35)
EmptyMessage/Wire/Unmarshal-12   6.77ns ± 2%  7.13ns ± 5%   +5.32%  (p=0.000 n=37+37)
RepeatedInt32/Wire/Unmarshal     9.46µs ± 1%  6.57µs ± 1%  -30.56%  (p=0.000 n=38+39)
RepeatedInt32/Wire/Unmarshal-12  1.50µs ± 2%  1.05µs ± 2%  -30.00%  (p=0.000 n=39+37)
Required/Wire/Unmarshal           371ns ± 1%   258ns ± 1%  -30.44%  (p=0.000 n=38+32)
Required/Wire/Unmarshal-12       60.3ns ± 1%  44.3ns ± 2%  -26.45%  (p=0.000 n=38+36)

Change-Id: Ie80415dea8cb6b840eafa52f0572046a1910a9b1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216419
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/encoding/wire/wire.go b/internal/encoding/wire/wire.go
index 32622b1..d7baa7f 100644
--- a/internal/encoding/wire/wire.go
+++ b/internal/encoding/wire/wire.go
@@ -176,7 +176,6 @@
 
 // AppendVarint appends v to b as a varint-encoded uint64.
 func AppendVarint(b []byte, v uint64) []byte {
-	// TODO: Specialize for sizes 1 and 2 with mid-stack inlining.
 	switch {
 	case v < 1<<7:
 		b = append(b, byte(v))
@@ -259,7 +258,6 @@
 // ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
 // This returns a negative length upon an error (see ParseError).
 func ConsumeVarint(b []byte) (v uint64, n int) {
-	// TODO: Specialize for sizes 1 and 2 with mid-stack inlining.
 	var y uint64
 	if len(b) <= 0 {
 		return 0, errCodeTruncated
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index fc93525..4a1d631 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -97,15 +97,27 @@
 	start := len(b)
 	for len(b) > 0 {
 		// Parse the tag (field number and wire type).
-		// TODO: inline 1 and 2 byte variants?
-		num, wtyp, n := wire.ConsumeTag(b)
-		if n < 0 {
-			return out, wire.ParseError(n)
+		var tag uint64
+		if b[0] < 0x80 {
+			tag = uint64(b[0])
+			b = b[1:]
+		} else if len(b) >= 2 && b[1] < 128 {
+			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
+			b = b[2:]
+		} else {
+			var n int
+			tag, n = wire.ConsumeVarint(b)
+			if n < 0 {
+				return out, wire.ParseError(n)
+			}
+			b = b[n:]
 		}
-		if num > wire.MaxValidNumber {
+		num := wire.Number(tag >> 3)
+		wtyp := wire.Type(tag & 7)
+
+		if num < wire.MinValidNumber || num > wire.MaxValidNumber {
 			return out, errors.New("invalid field number")
 		}
-		b = b[n:]
 
 		if wtyp == wire.EndGroupType {
 			if num != groupTag {
@@ -121,6 +133,7 @@
 		} else {
 			f = mi.coderFields[num]
 		}
+		var n int
 		err := errUnknown
 		switch {
 		case f != nil: