proto: return an error instead of producing invalid wire format

There currently is no risk of producing invalid wire format,
but that will change with subsequent changes regarding lazy decoding.

We have been running this change in production for about a month,
without ever triggering the check (until lazy decoding is involved).

related to golang/protobuf#1609

Change-Id: I3c5c956aee2fa81f99dea03ed2a977a1547081fc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579595
Auto-Submit: Michael Stapelberg <stapelberg@google.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/errors/errors.go b/internal/errors/errors.go
index 20c17b3..d967198 100644
--- a/internal/errors/errors.go
+++ b/internal/errors/errors.go
@@ -87,3 +87,18 @@
 func RequiredNotSet(name string) error {
 	return New("required field %v not set", name)
 }
+
+type SizeMismatchError struct {
+	Calculated, Measured int
+}
+
+func (e *SizeMismatchError) Error() string {
+	return fmt.Sprintf("size mismatch (see https://github.com/golang/protobuf/issues/1609): calculated=%d, measured=%d", e.Calculated, e.Measured)
+}
+
+func MismatchedSizeCalculation(calculated, measured int) error {
+	return &SizeMismatchError{
+		Calculated: calculated,
+		Measured:   measured,
+	}
+}
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 3fadd24..e8040ef 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -233,9 +233,15 @@
 }
 
 func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+	calculatedSize := f.mi.sizePointer(p.Elem(), opts)
 	b = protowire.AppendVarint(b, f.wiretag)
-	b = protowire.AppendVarint(b, uint64(f.mi.sizePointer(p.Elem(), opts)))
-	return f.mi.marshalAppendPointer(b, p.Elem(), opts)
+	b = protowire.AppendVarint(b, uint64(calculatedSize))
+	before := len(b)
+	b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
+	if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
+		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
+	}
+	return b, err
 }
 
 func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
@@ -267,9 +273,15 @@
 }
 
 func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
+	calculatedSize := proto.Size(m)
 	b = protowire.AppendVarint(b, wiretag)
-	b = protowire.AppendVarint(b, uint64(proto.Size(m)))
-	return opts.Options().MarshalAppend(b, m)
+	b = protowire.AppendVarint(b, uint64(calculatedSize))
+	before := len(b)
+	b, err := opts.Options().MarshalAppend(b, m)
+	if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
+		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
+	}
+	return b, err
 }
 
 func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
@@ -482,10 +494,14 @@
 		b = protowire.AppendVarint(b, f.wiretag)
 		siz := f.mi.sizePointer(v, opts)
 		b = protowire.AppendVarint(b, uint64(siz))
+		before := len(b)
 		b, err = f.mi.marshalAppendPointer(b, v, opts)
 		if err != nil {
 			return b, err
 		}
+		if measuredSize := len(b) - before; siz != measuredSize {
+			return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
+		}
 	}
 	return b, nil
 }
@@ -538,10 +554,14 @@
 		b = protowire.AppendVarint(b, wiretag)
 		siz := proto.Size(m)
 		b = protowire.AppendVarint(b, uint64(siz))
+		before := len(b)
 		b, err = opts.Options().MarshalAppend(b, m)
 		if err != nil {
 			return b, err
 		}
+		if measuredSize := len(b) - before; siz != measuredSize {
+			return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
+		}
 	}
 	return b, nil
 }
@@ -599,11 +619,15 @@
 		b = protowire.AppendVarint(b, wiretag)
 		siz := proto.Size(m)
 		b = protowire.AppendVarint(b, uint64(siz))
+		before := len(b)
 		var err error
 		b, err = mopts.MarshalAppend(b, m)
 		if err != nil {
 			return b, err
 		}
+		if measuredSize := len(b) - before; siz != measuredSize {
+			return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
+		}
 	}
 	return b, nil
 }
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 111b9d1..fb35f0b 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -9,6 +9,7 @@
 	"sort"
 
 	"google.golang.org/protobuf/encoding/protowire"
+	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/genid"
 	"google.golang.org/protobuf/reflect/protoreflect"
 )
@@ -240,11 +241,16 @@
 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 		size += mapi.valFuncs.size(val, mapValTagSize, opts)
 		b = protowire.AppendVarint(b, uint64(size))
+		before := len(b)
 		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 		if err != nil {
 			return nil, err
 		}
-		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
+		b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
+		if measuredSize := len(b) - before; size != measuredSize && err == nil {
+			return nil, errors.MismatchedSizeCalculation(size, measuredSize)
+		}
+		return b, err
 	} else {
 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 		val := pointerOfValue(valrv)
@@ -259,7 +265,12 @@
 		}
 		b = protowire.AppendVarint(b, mapi.valWiretag)
 		b = protowire.AppendVarint(b, uint64(valSize))
-		return f.mi.marshalAppendPointer(b, val, opts)
+		before := len(b)
+		b, err = f.mi.marshalAppendPointer(b, val, opts)
+		if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
+			return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
+		}
+		return b, err
 	}
 }
 
diff --git a/proto/encode.go b/proto/encode.go
index 5f69539..c779f6d 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -5,12 +5,17 @@
 package proto
 
 import (
+	"errors"
+	"fmt"
+
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/order"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoiface"
+
+	protoerrors "google.golang.org/protobuf/internal/errors"
 )
 
 // MarshalOptions configures the marshaler.
@@ -175,6 +180,10 @@
 		out.Buf, err = o.marshalMessageSlow(b, m)
 	}
 	if err != nil {
+		var mismatch *protoerrors.SizeMismatchError
+		if errors.As(err, &mismatch) {
+			return out, fmt.Errorf("marshaling %s: %v", string(m.Descriptor().FullName()), err)
+		}
 		return out, err
 	}
 	if allowPartial {
diff --git a/proto/messageset.go b/proto/messageset.go
index 312d5d4..575d148 100644
--- a/proto/messageset.go
+++ b/proto/messageset.go
@@ -47,11 +47,16 @@
 func (o MarshalOptions) marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
 	b = messageset.AppendFieldStart(b, fd.Number())
 	b = protowire.AppendTag(b, messageset.FieldMessage, protowire.BytesType)
-	b = protowire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
+	calculatedSize := o.Size(value.Message().Interface())
+	b = protowire.AppendVarint(b, uint64(calculatedSize))
+	before := len(b)
 	b, err := o.marshalMessage(b, value.Message())
 	if err != nil {
 		return b, err
 	}
+	if measuredSize := len(b) - before; calculatedSize != measuredSize {
+		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
+	}
 	b = messageset.AppendFieldEnd(b)
 	return b, nil
 }