| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package packet |
| |
| import ( |
| "bytes" |
| "encoding/hex" |
| "fmt" |
| "golang.org/x/crypto/openpgp/errors" |
| "io" |
| "testing" |
| ) |
| |
| func TestReadFull(t *testing.T) { |
| var out [4]byte |
| |
| b := bytes.NewBufferString("foo") |
| n, err := readFull(b, out[:3]) |
| if n != 3 || err != nil { |
| t.Errorf("full read failed n:%d err:%s", n, err) |
| } |
| |
| b = bytes.NewBufferString("foo") |
| n, err = readFull(b, out[:4]) |
| if n != 3 || err != io.ErrUnexpectedEOF { |
| t.Errorf("partial read failed n:%d err:%s", n, err) |
| } |
| |
| b = bytes.NewBuffer(nil) |
| n, err = readFull(b, out[:3]) |
| if n != 0 || err != io.ErrUnexpectedEOF { |
| t.Errorf("empty read failed n:%d err:%s", n, err) |
| } |
| } |
| |
// readerFromHex decodes the hex string s and returns a reader positioned at
// the start of the decoded bytes. It panics when s is not valid hex; the
// callers in this file only pass hard-coded test vectors, so a decode failure
// is a mistake in the test data. The panic message includes the offending
// input and the decoder's error so the bad vector can be located quickly.
func readerFromHex(s string) io.Reader {
	data, err := hex.DecodeString(s)
	if err != nil {
		panic(fmt.Sprintf("readerFromHex: bad input %q: %v", s, err))
	}
	return bytes.NewBuffer(data)
}
| |
// readLengthTests is the case table consumed by TestReadLength. hexInput is
// the hex-encoded byte stream handed to readLength. When err is non-nil it is
// the error the call must return; otherwise length and isPartial are the
// values the call must produce. Fields left out of a case keep their zero
// value.
var readLengthTests = []struct {
	hexInput  string
	length    int64
	isPartial bool
	err       error
}{
	{hexInput: "", err: io.ErrUnexpectedEOF},
	{hexInput: "1f", length: 31},
	{hexInput: "c0", err: io.ErrUnexpectedEOF},
	{hexInput: "c101", length: 256 + 1 + 192},
	{hexInput: "e0", length: 1, isPartial: true},
	{hexInput: "e1", length: 2, isPartial: true},
	{hexInput: "e2", length: 4, isPartial: true},
	{hexInput: "ff", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00", err: io.ErrUnexpectedEOF},
	{hexInput: "ff0000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff000000", err: io.ErrUnexpectedEOF},
	{hexInput: "ff00000000"},
	{hexInput: "ff01020304", length: 0x01020304},
}
| |
| func TestReadLength(t *testing.T) { |
| for i, test := range readLengthTests { |
| length, isPartial, err := readLength(readerFromHex(test.hexInput)) |
| if test.err != nil { |
| if err != test.err { |
| t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err) |
| } |
| continue |
| } |
| if err != nil { |
| t.Errorf("%d: unexpected error: %s", i, err) |
| continue |
| } |
| if length != test.length || isPartial != test.isPartial { |
| t.Errorf("%d: bad result got:(%d,%t) want:(%d,%t)", i, length, isPartial, test.length, test.isPartial) |
| } |
| } |
| } |
| |
// partialLengthReaderTests is the case table consumed by
// TestPartialLengthReader. hexInput is the hex-encoded stream wrapped by the
// reader under test. When err is non-nil it is the error that reading the
// stream to the end must yield; otherwise hexOutput is the hex encoding of
// the bytes that must come out.
var partialLengthReaderTests = []struct {
	hexInput  string
	err       error
	hexOutput string
}{
	{hexInput: "e0", err: io.ErrUnexpectedEOF},
	{hexInput: "e001", err: io.ErrUnexpectedEOF},
	{hexInput: "e0010102", hexOutput: "0102"},
	{hexInput: "ff00000000"},
	{hexInput: "e10102e1030400", hexOutput: "01020304"},
	{hexInput: "e101", err: io.ErrUnexpectedEOF},
}
| |
| func TestPartialLengthReader(t *testing.T) { |
| for i, test := range partialLengthReaderTests { |
| r := &partialLengthReader{readerFromHex(test.hexInput), 0, true} |
| out, err := io.ReadAll(r) |
| if test.err != nil { |
| if err != test.err { |
| t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err) |
| } |
| continue |
| } |
| if err != nil { |
| t.Errorf("%d: unexpected error: %s", i, err) |
| continue |
| } |
| |
| got := fmt.Sprintf("%x", out) |
| if got != test.hexOutput { |
| t.Errorf("%d: got:%s want:%s", i, test.hexOutput, got) |
| } |
| } |
| } |
| |
// readHeaderTests is the case table consumed by TestReadHeader. hexInput is
// the hex-encoded stream passed to readHeader. structuralError marks inputs
// that must be rejected with an errors.StructuralError; unexpectedEOF marks
// inputs for which io.ErrUnexpectedEOF is acceptable, either from readHeader
// itself or while reading the packet contents. tag and length are the header
// values expected on success and hexOutput is the hex encoding of the
// expected body. Fields left out of a case keep their zero value.
var readHeaderTests = []struct {
	hexInput        string
	structuralError bool
	unexpectedEOF   bool
	tag             int
	length          int64
	hexOutput       string
}{
	{hexInput: ""},
	{hexInput: "7f", structuralError: true},

	// Old format headers
	{hexInput: "80", unexpectedEOF: true},
	{hexInput: "8001", unexpectedEOF: true, length: 1},
	{hexInput: "800102", length: 1, hexOutput: "02"},
	{hexInput: "81000102", length: 1, hexOutput: "02"},
	{hexInput: "820000000102", length: 1, hexOutput: "02"},
	{hexInput: "860000000102", tag: 1, length: 1, hexOutput: "02"},
	{hexInput: "83010203", length: -1, hexOutput: "010203"},

	// New format headers
	{hexInput: "c0", unexpectedEOF: true},
	{hexInput: "c000"},
	{hexInput: "c00102", length: 1, hexOutput: "02"},
	{hexInput: "c0020203", length: 2, hexOutput: "0203"},
	{hexInput: "c00202", unexpectedEOF: true, length: 2},
	{hexInput: "c3020203", tag: 3, length: 2, hexOutput: "0203"},
}
| |
| func TestReadHeader(t *testing.T) { |
| for i, test := range readHeaderTests { |
| tag, length, contents, err := readHeader(readerFromHex(test.hexInput)) |
| if test.structuralError { |
| if _, ok := err.(errors.StructuralError); ok { |
| continue |
| } |
| t.Errorf("%d: expected StructuralError, got:%s", i, err) |
| continue |
| } |
| if err != nil { |
| if len(test.hexInput) == 0 && err == io.EOF { |
| continue |
| } |
| if !test.unexpectedEOF || err != io.ErrUnexpectedEOF { |
| t.Errorf("%d: unexpected error from readHeader: %s", i, err) |
| } |
| continue |
| } |
| if int(tag) != test.tag || length != test.length { |
| t.Errorf("%d: got:(%d,%d) want:(%d,%d)", i, int(tag), length, test.tag, test.length) |
| continue |
| } |
| |
| body, err := io.ReadAll(contents) |
| if err != nil { |
| if !test.unexpectedEOF || err != io.ErrUnexpectedEOF { |
| t.Errorf("%d: unexpected error from contents: %s", i, err) |
| } |
| continue |
| } |
| if test.unexpectedEOF { |
| t.Errorf("%d: expected ErrUnexpectedEOF from contents but got no error", i) |
| continue |
| } |
| got := fmt.Sprintf("%x", body) |
| if got != test.hexOutput { |
| t.Errorf("%d: got:%s want:%s", i, got, test.hexOutput) |
| } |
| } |
| } |
| |
| func TestSerializeHeader(t *testing.T) { |
| tag := packetTypePublicKey |
| lengths := []int{0, 1, 2, 64, 192, 193, 8000, 8384, 8385, 10000} |
| |
| for _, length := range lengths { |
| buf := bytes.NewBuffer(nil) |
| serializeHeader(buf, tag, length) |
| tag2, length2, _, err := readHeader(buf) |
| if err != nil { |
| t.Errorf("length %d, err: %s", length, err) |
| } |
| if tag2 != tag { |
| t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag) |
| } |
| if int(length2) != length { |
| t.Errorf("length %d, length incorrect (got %d)", length, length2) |
| } |
| } |
| } |
| |
| func TestPartialLengths(t *testing.T) { |
| buf := bytes.NewBuffer(nil) |
| w := new(partialLengthWriter) |
| w.w = noOpCloser{buf} |
| |
| const maxChunkSize = 64 |
| |
| var b [maxChunkSize]byte |
| var n uint8 |
| for l := 1; l <= maxChunkSize; l++ { |
| for i := 0; i < l; i++ { |
| b[i] = n |
| n++ |
| } |
| m, err := w.Write(b[:l]) |
| if m != l { |
| t.Errorf("short write got: %d want: %d", m, l) |
| } |
| if err != nil { |
| t.Errorf("error from write: %s", err) |
| } |
| } |
| if err := w.Close(); err != nil { |
| t.Fatal(err) |
| } |
| |
| // The first packet should be at least 512 bytes. |
| first, err := buf.ReadByte() |
| if err != nil { |
| t.Fatal(err) |
| } |
| if plen := 1 << (first & 0x1f); plen < 512 { |
| t.Errorf("first packet too short: got %d want at least %d", plen, 512) |
| } |
| if err := buf.UnreadByte(); err != nil { |
| t.Fatal(err) |
| } |
| |
| want := (maxChunkSize * (maxChunkSize + 1)) / 2 |
| copyBuf := bytes.NewBuffer(nil) |
| r := &partialLengthReader{buf, 0, true} |
| m, err := io.Copy(copyBuf, r) |
| if m != int64(want) { |
| t.Errorf("short copy got: %d want: %d", m, want) |
| } |
| if err != nil { |
| t.Errorf("error from copy: %s", err) |
| } |
| |
| copyBytes := copyBuf.Bytes() |
| for i := 0; i < want; i++ { |
| if copyBytes[i] != uint8(i) { |
| t.Errorf("bad pattern in copy at %d", i) |
| break |
| } |
| } |
| } |
| |
| func TestPartialLengthsShortWrite(t *testing.T) { |
| buf := bytes.NewBuffer(nil) |
| w := &partialLengthWriter{ |
| w: noOpCloser{buf}, |
| } |
| data := bytes.Repeat([]byte("a"), 510) |
| if _, err := w.Write(data); err != nil { |
| t.Fatal(err) |
| } |
| if err := w.Close(); err != nil { |
| t.Fatal(err) |
| } |
| copyBuf := bytes.NewBuffer(nil) |
| r := &partialLengthReader{buf, 0, true} |
| if _, err := io.Copy(copyBuf, r); err != nil { |
| t.Fatal(err) |
| } |
| if !bytes.Equal(copyBuf.Bytes(), data) { |
| t.Errorf("got %q want %q", buf.Bytes(), data) |
| } |
| } |