openpgp/packet: ensure that first partial packet is 512 bytes
This requirement is from RFC 4880 4.2.2.4.
Also simplify the partialLengthWriter loop. The old code worked but
was written in a confusing way, with a loop whose terminating condition
didn't make sense and was never true in practice.
Rewrite it to more clearly do a set of partial writes of decreasing size.
Fixes golang/go#32474
Change-Id: Ia53ceb39a34f1d6f2ea7c60190d52948bb0db59b
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/181121
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
diff --git a/openpgp/packet/packet.go b/openpgp/packet/packet.go
index 5af64c5..9728d61 100644
--- a/openpgp/packet/packet.go
+++ b/openpgp/packet/packet.go
@@ -14,6 +14,7 @@
"crypto/rsa"
"io"
"math/big"
+ "math/bits"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/openpgp/errors"
@@ -100,33 +101,65 @@
type partialLengthWriter struct {
w io.WriteCloser
lengthByte [1]byte
+ sentFirst bool
+ buf []byte
}
+// RFC 4880 4.2.2.4: the first partial length MUST be at least 512 octets long.
+const minFirstPartialWrite = 512
+
func (w *partialLengthWriter) Write(p []byte) (n int, err error) {
- for len(p) > 0 {
- for power := uint(14); power < 32; power-- {
- l := 1 << power
- if len(p) >= l {
- w.lengthByte[0] = 224 + uint8(power)
- _, err = w.w.Write(w.lengthByte[:])
- if err != nil {
- return
- }
- var m int
- m, err = w.w.Write(p[:l])
- n += m
- if err != nil {
- return
- }
- p = p[l:]
- break
+ off := 0
+ if !w.sentFirst {
+ if len(w.buf) > 0 || len(p) < minFirstPartialWrite {
+ off = len(w.buf)
+ w.buf = append(w.buf, p...)
+ if len(w.buf) < minFirstPartialWrite {
+ return len(p), nil
}
+ p = w.buf
+ w.buf = nil
}
+ w.sentFirst = true
}
- return
+
+ power := uint8(30)
+ for len(p) > 0 {
+ l := 1 << power
+ if len(p) < l {
+ power = uint8(bits.Len32(uint32(len(p)))) - 1
+ l = 1 << power
+ }
+ w.lengthByte[0] = 224 + power
+ _, err = w.w.Write(w.lengthByte[:])
+ if err == nil {
+ var m int
+ m, err = w.w.Write(p[:l])
+ n += m
+ }
+ if err != nil {
+ if n < off {
+ return 0, err
+ }
+ return n - off, err
+ }
+ p = p[l:]
+ }
+ return n - off, nil
}
func (w *partialLengthWriter) Close() error {
+ if len(w.buf) > 0 {
+ // In this case we can't send a 512 byte packet.
+ // Just send what we have.
+ p := w.buf
+ w.sentFirst = true
+ w.buf = nil
+ if _, err := w.Write(p); err != nil {
+ return err
+ }
+ }
+
w.lengthByte[0] = 0
_, err := w.w.Write(w.lengthByte[:])
if err != nil {
diff --git a/openpgp/packet/packet_test.go b/openpgp/packet/packet_test.go
index 1dab5c3..63a8387 100644
--- a/openpgp/packet/packet_test.go
+++ b/openpgp/packet/packet_test.go
@@ -232,7 +232,21 @@
t.Errorf("error from write: %s", err)
}
}
- w.Close()
+ if err := w.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // The first packet should be at least 512 bytes.
+ first, err := buf.ReadByte()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if plen := 1 << (first & 0x1f); plen < 512 {
+ t.Errorf("first packet too short: got %d want at least %d", plen, 512)
+ }
+ if err := buf.UnreadByte(); err != nil {
+ t.Fatal(err)
+ }
want := (maxChunkSize * (maxChunkSize + 1)) / 2
copyBuf := bytes.NewBuffer(nil)
@@ -253,3 +267,25 @@
}
}
}
+
+func TestPartialLengthsShortWrite(t *testing.T) {
+ buf := bytes.NewBuffer(nil)
+ w := &partialLengthWriter{
+ w: noOpCloser{buf},
+ }
+ data := bytes.Repeat([]byte("a"), 510)
+ if _, err := w.Write(data); err != nil {
+ t.Fatal(err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatal(err)
+ }
+ copyBuf := bytes.NewBuffer(nil)
+ r := &partialLengthReader{buf, 0, true}
+ if _, err := io.Copy(copyBuf, r); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(copyBuf.Bytes(), data) {
+ t.Errorf("got %q want %q", buf.Bytes(), data)
+ }
+}