openpgp/packet: ensure that first partial packet is 512 bytes

This requirement is from RFC 4880 4.2.2.4.

Also simplify the partialLengthWriter loop. The old code worked but
was written in a confusing way, with a loop whose terminating condition
didn't make sense and was never true in practice.
Rewrite it to more clearly do a set of partial writes of decreasing size.

Fixes golang/go#32474

Change-Id: Ia53ceb39a34f1d6f2ea7c60190d52948bb0db59b
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/181121
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
diff --git a/openpgp/packet/packet.go b/openpgp/packet/packet.go
index 5af64c5..9728d61 100644
--- a/openpgp/packet/packet.go
+++ b/openpgp/packet/packet.go
@@ -14,6 +14,7 @@
 	"crypto/rsa"
 	"io"
 	"math/big"
+	"math/bits"
 
 	"golang.org/x/crypto/cast5"
 	"golang.org/x/crypto/openpgp/errors"
@@ -100,33 +101,65 @@
 type partialLengthWriter struct {
 	w          io.WriteCloser
 	lengthByte [1]byte
+	sentFirst  bool
+	buf        []byte
 }
 
+// RFC 4880 4.2.2.4: the first partial length MUST be at least 512 octets long.
+const minFirstPartialWrite = 512
+
 func (w *partialLengthWriter) Write(p []byte) (n int, err error) {
-	for len(p) > 0 {
-		for power := uint(14); power < 32; power-- {
-			l := 1 << power
-			if len(p) >= l {
-				w.lengthByte[0] = 224 + uint8(power)
-				_, err = w.w.Write(w.lengthByte[:])
-				if err != nil {
-					return
-				}
-				var m int
-				m, err = w.w.Write(p[:l])
-				n += m
-				if err != nil {
-					return
-				}
-				p = p[l:]
-				break
+	off := 0
+	if !w.sentFirst {
+		if len(w.buf) > 0 || len(p) < minFirstPartialWrite {
+			off = len(w.buf)
+			w.buf = append(w.buf, p...)
+			if len(w.buf) < minFirstPartialWrite {
+				return len(p), nil
 			}
+			p = w.buf
+			w.buf = nil
 		}
+		w.sentFirst = true
 	}
-	return
+
+	power := uint8(30)
+	for len(p) > 0 {
+		l := 1 << power
+		if len(p) < l {
+			power = uint8(bits.Len32(uint32(len(p)))) - 1
+			l = 1 << power
+		}
+		w.lengthByte[0] = 224 + power
+		_, err = w.w.Write(w.lengthByte[:])
+		if err == nil {
+			var m int
+			m, err = w.w.Write(p[:l])
+			n += m
+		}
+		if err != nil {
+			if n < off {
+				return 0, err
+			}
+			return n - off, err
+		}
+		p = p[l:]
+	}
+	return n - off, nil
 }
 
 func (w *partialLengthWriter) Close() error {
+	if len(w.buf) > 0 {
+		// In this case we can't send a 512 byte packet.
+		// Just send what we have.
+		p := w.buf
+		w.sentFirst = true
+		w.buf = nil
+		if _, err := w.Write(p); err != nil {
+			return err
+		}
+	}
+
 	w.lengthByte[0] = 0
 	_, err := w.w.Write(w.lengthByte[:])
 	if err != nil {
diff --git a/openpgp/packet/packet_test.go b/openpgp/packet/packet_test.go
index 1dab5c3..63a8387 100644
--- a/openpgp/packet/packet_test.go
+++ b/openpgp/packet/packet_test.go
@@ -232,7 +232,21 @@
 			t.Errorf("error from write: %s", err)
 		}
 	}
-	w.Close()
+	if err := w.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	// The first packet should be at least 512 bytes.
+	first, err := buf.ReadByte()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if plen := 1 << (first & 0x1f); plen < 512 {
+		t.Errorf("first packet too short: got %d want at least %d", plen, 512)
+	}
+	if err := buf.UnreadByte(); err != nil {
+		t.Fatal(err)
+	}
 
 	want := (maxChunkSize * (maxChunkSize + 1)) / 2
 	copyBuf := bytes.NewBuffer(nil)
@@ -253,3 +267,25 @@
 		}
 	}
 }
+
+func TestPartialLengthsShortWrite(t *testing.T) {
+	buf := bytes.NewBuffer(nil)
+	w := &partialLengthWriter{
+		w: noOpCloser{buf},
+	}
+	data := bytes.Repeat([]byte("a"), 510)
+	if _, err := w.Write(data); err != nil {
+		t.Fatal(err)
+	}
+	if err := w.Close(); err != nil {
+		t.Fatal(err)
+	}
+	copyBuf := bytes.NewBuffer(nil)
+	r := &partialLengthReader{buf, 0, true}
+	if _, err := io.Copy(copyBuf, r); err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(copyBuf.Bytes(), data) {
+		t.Errorf("got %q want %q", buf.Bytes(), data)
+	}
+}