| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package ssh |
| |
| import ( |
| "bytes" |
| "math/big" |
| "math/rand" |
| "reflect" |
| "testing" |
| "testing/quick" |
| ) |
| |
// intLengthTests pairs an integer with the length intLength is expected to
// report for it, written as a 4-byte prefix plus the number of value bytes.
var intLengthTests = []struct {
	val, length int
}{
	{val: 0, length: 4 + 0},
	{val: 1, length: 4 + 1},
	{val: 127, length: 4 + 1},
	{val: 128, length: 4 + 2},
	{val: -1, length: 4 + 1},
}
| |
| func TestIntLength(t *testing.T) { |
| for _, test := range intLengthTests { |
| v := new(big.Int).SetInt64(int64(test.val)) |
| length := intLength(v) |
| if length != test.length { |
| t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) |
| } |
| } |
| } |
| |
// msgAllTypes is a test fixture with one field of each kind the marshaling
// code is exercised with here. TestMarshalUnmarshal round-trips random values
// of it, produced by its Generate method.
// NOTE(review): field order presumably determines wire order; do not reorder
// without checking Marshal/Unmarshal.
type msgAllTypes struct {
	// The sshtype tag on the first field selects the message type
	// (TestMarshalMultiTag shows a mismatched type is rejected).
	Bool    bool `sshtype:"21"`
	Array   [16]byte
	Uint64  uint64
	Uint32  uint32
	Uint8   uint8
	String  string
	Strings []string
	Bytes   []byte
	Int     *big.Int
	// The "rest" tag presumably makes this field take all remaining bytes
	// of the packet — confirm against Unmarshal.
	Rest []byte `ssh:"rest"`
}
| |
| func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { |
| m := &msgAllTypes{} |
| m.Bool = rand.Intn(2) == 1 |
| randomBytes(m.Array[:], rand) |
| m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) |
| m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) |
| m.Uint8 = uint8(rand.Intn(1 << 8)) |
| m.String = string(m.Array[:]) |
| m.Strings = randomNameList(rand) |
| m.Bytes = m.Array[:] |
| m.Int = randomInt(rand) |
| m.Rest = m.Array[:] |
| return reflect.ValueOf(m) |
| } |
| |
| func TestMarshalUnmarshal(t *testing.T) { |
| rand := rand.New(rand.NewSource(0)) |
| iface := &msgAllTypes{} |
| ty := reflect.ValueOf(iface).Type() |
| |
| n := 100 |
| if testing.Short() { |
| n = 5 |
| } |
| for j := 0; j < n; j++ { |
| v, ok := quick.Value(ty, rand) |
| if !ok { |
| t.Errorf("failed to create value") |
| break |
| } |
| |
| m1 := v.Elem().Interface() |
| m2 := iface |
| |
| marshaled := Marshal(m1) |
| if err := Unmarshal(marshaled, m2); err != nil { |
| t.Errorf("Unmarshal %#v: %s", m1, err) |
| break |
| } |
| |
| if !reflect.DeepEqual(v.Interface(), m2) { |
| t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) |
| break |
| } |
| } |
| } |
| |
| func TestUnmarshalEmptyPacket(t *testing.T) { |
| var b []byte |
| var m channelRequestSuccessMsg |
| if err := Unmarshal(b, &m); err == nil { |
| t.Fatalf("unmarshal of empty slice succeeded") |
| } |
| } |
| |
| func TestUnmarshalUnexpectedPacket(t *testing.T) { |
| type S struct { |
| I uint32 `sshtype:"43"` |
| S string |
| B bool |
| } |
| |
| s := S{11, "hello", true} |
| packet := Marshal(s) |
| packet[0] = 42 |
| roundtrip := S{} |
| err := Unmarshal(packet, &roundtrip) |
| if err == nil { |
| t.Fatal("expected error, not nil") |
| } |
| } |
| |
| func TestMarshalPtr(t *testing.T) { |
| s := struct { |
| S string |
| }{"hello"} |
| |
| m1 := Marshal(s) |
| m2 := Marshal(&s) |
| if !bytes.Equal(m1, m2) { |
| t.Errorf("got %q, want %q for marshaled pointer", m2, m1) |
| } |
| } |
| |
| func TestBareMarshalUnmarshal(t *testing.T) { |
| type S struct { |
| I uint32 |
| S string |
| B bool |
| } |
| |
| s := S{42, "hello", true} |
| packet := Marshal(s) |
| roundtrip := S{} |
| Unmarshal(packet, &roundtrip) |
| |
| if !reflect.DeepEqual(s, roundtrip) { |
| t.Errorf("got %#v, want %#v", roundtrip, s) |
| } |
| } |
| |
| func TestBareMarshal(t *testing.T) { |
| type S2 struct { |
| I uint32 |
| } |
| s := S2{42} |
| packet := Marshal(s) |
| i, rest, ok := parseUint32(packet) |
| if len(rest) > 0 || !ok { |
| t.Errorf("parseInt(%q): parse error", packet) |
| } |
| if i != s.I { |
| t.Errorf("got %d, want %d", i, s.I) |
| } |
| } |
| |
| func TestUnmarshalShortKexInitPacket(t *testing.T) { |
| // This used to panic. |
| // Issue 11348 |
| packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} |
| kim := &kexInitMsg{} |
| if err := Unmarshal(packet, kim); err == nil { |
| t.Error("truncated packet unmarshaled without error") |
| } |
| } |
| |
| func TestMarshalMultiTag(t *testing.T) { |
| var res struct { |
| A uint32 `sshtype:"1|2"` |
| } |
| |
| good1 := struct { |
| A uint32 `sshtype:"1"` |
| }{ |
| 1, |
| } |
| good2 := struct { |
| A uint32 `sshtype:"2"` |
| }{ |
| 1, |
| } |
| |
| if e := Unmarshal(Marshal(good1), &res); e != nil { |
| t.Errorf("error unmarshaling multipart tag: %v", e) |
| } |
| |
| if e := Unmarshal(Marshal(good2), &res); e != nil { |
| t.Errorf("error unmarshaling multipart tag: %v", e) |
| } |
| |
| bad1 := struct { |
| A uint32 `sshtype:"3"` |
| }{ |
| 1, |
| } |
| if e := Unmarshal(Marshal(bad1), &res); e == nil { |
| t.Errorf("bad struct unmarshaled without error") |
| } |
| } |
| |
// randomBytes overwrites every byte of out with the low 8 bits of successive
// rand.Int31 draws, one draw per byte.
func randomBytes(out []byte, rand *rand.Rand) {
	for i := range out {
		out[i] = byte(rand.Int31())
	}
}
| |
// randomNameList returns between 0 and 15 names, each 1 to 16 bytes long and
// made of the letters 'a' through 'p'. The returned slice is non-nil even
// when it is empty.
func randomNameList(rand *rand.Rand) []string {
	count := int(rand.Int31() & 15)
	names := make([]string, count)
	for i := 0; i < count; i++ {
		size := 1 + int(rand.Int31()&15)
		buf := make([]byte, 0, size)
		for len(buf) < size {
			buf = append(buf, 'a'+byte(rand.Int31()&15))
		}
		names[i] = string(buf)
	}
	return names
}
| |
// randomInt returns a big.Int holding one rand.Uint32 draw reinterpreted as
// a signed 32-bit value, so both negative and positive values occur.
func randomInt(rand *rand.Rand) *big.Int {
	signed := int32(rand.Uint32())
	return big.NewInt(int64(signed))
}
| |
| func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
| ki := &kexInitMsg{} |
| randomBytes(ki.Cookie[:], rand) |
| ki.KexAlgos = randomNameList(rand) |
| ki.ServerHostKeyAlgos = randomNameList(rand) |
| ki.CiphersClientServer = randomNameList(rand) |
| ki.CiphersServerClient = randomNameList(rand) |
| ki.MACsClientServer = randomNameList(rand) |
| ki.MACsServerClient = randomNameList(rand) |
| ki.CompressionClientServer = randomNameList(rand) |
| ki.CompressionServerClient = randomNameList(rand) |
| ki.LanguagesClientServer = randomNameList(rand) |
| ki.LanguagesServerClient = randomNameList(rand) |
| if rand.Int31()&1 == 1 { |
| ki.FirstKexFollows = true |
| } |
| return reflect.ValueOf(ki) |
| } |
| |
| func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
| dhi := &kexDHInitMsg{} |
| dhi.X = randomInt(rand) |
| return reflect.ValueOf(dhi) |
| } |
| |
| var ( |
| _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
| _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
| |
| _kexInit = Marshal(_kexInitMsg) |
| _kexDHInit = Marshal(_kexDHInitMsg) |
| ) |
| |
| func BenchmarkMarshalKexInitMsg(b *testing.B) { |
| for i := 0; i < b.N; i++ { |
| Marshal(_kexInitMsg) |
| } |
| } |
| |
| func BenchmarkUnmarshalKexInitMsg(b *testing.B) { |
| m := new(kexInitMsg) |
| for i := 0; i < b.N; i++ { |
| Unmarshal(_kexInit, m) |
| } |
| } |
| |
| func BenchmarkMarshalKexDHInitMsg(b *testing.B) { |
| for i := 0; i < b.N; i++ { |
| Marshal(_kexDHInitMsg) |
| } |
| } |
| |
| func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { |
| m := new(kexDHInitMsg) |
| for i := 0; i < b.N; i++ { |
| Unmarshal(_kexDHInit, m) |
| } |
| } |