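go.crypto/ssh: improve marshal performance

marshal now starts from a preallocated buffer, appends big-endian
integers through the new appendU16/appendU32/appendInt helpers, copies
[]byte fields with a single append, and writes name-lists in one pass,
patching the length prefix afterwards instead of pre-computing it.
The uint32/uint64 parse and marshal helpers use encoding/binary
instead of hand-written shifts.

Benchmark results (old vs. new, ns/op):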
Atom N450, 6g
benchmark old ns/op new ns/op delta
BenchmarkMarshalKexInitMsg 96446 66675 -30.87%
BenchmarkUnmarshalKexInitMsg 155341 142715 -8.13%
BenchmarkMarshalKexDHInitMsg 9499 8340 -12.20%
BenchmarkUnmarshalKexDHInitMsg 4973 5145 +3.46%
Intel E3-1270, 6g
benchmark old ns/op new ns/op delta
BenchmarkMarshalKexInitMsg 23218 16903 -27.20%
BenchmarkUnmarshalKexInitMsg 31384 31640 +0.82%
BenchmarkMarshalKexDHInitMsg 1943 1661 -14.51%
BenchmarkUnmarshalKexDHInitMsg 915 941 +2.84%
R=agl, minux.ma, remyoudompheng
CC=golang-dev
https://golang.org/cl/5728053
diff --git a/ssh/common.go b/ssh/common.go
index 6c65b95..3368c3c 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -249,3 +249,15 @@
}
return string(out)
}
+
+func appendU16(buf []byte, n uint16) []byte {
+ return append(buf, byte(n>>8), byte(n))
+}
+
+func appendU32(buf []byte, n uint32) []byte {
+ return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
+}
+
+func appendInt(buf []byte, n int) []byte {
+ return appendU32(buf, uint32(n))
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index eac8b8d..d61c6c7 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -6,6 +6,7 @@
import (
"bytes"
+ "encoding/binary"
"io"
"math/big"
"reflect"
@@ -220,7 +221,7 @@
if len(packet) < t.Len() {
return ParseError{expectedType}
}
- for j := 0; j < t.Len(); j++ {
+ for j, n := 0, t.Len(); j < n; j++ {
field.Index(j).Set(reflect.ValueOf(packet[j]))
}
packet = packet[t.Len():]
@@ -282,15 +283,13 @@
// marshal serializes the message in msg, using the given message type.
func marshal(msgType uint8, msg interface{}) []byte {
- var out []byte
- out = append(out, msgType)
+ out := make([]byte, 1, 64)
+ out[0] = msgType
v := reflect.ValueOf(msg)
- structType := v.Type()
- for i := 0; i < v.NumField(); i++ {
+ for i, n := 0, v.NumField(); i < n; i++ {
field := v.Field(i)
- t := field.Type()
- switch t.Kind() {
+ switch t := field.Type(); t.Kind() {
case reflect.Bool:
var v uint8
if field.Bool() {
@@ -301,53 +300,35 @@
if t.Elem().Kind() != reflect.Uint8 {
panic("array of non-uint8")
}
- for j := 0; j < t.Len(); j++ {
- out = append(out, byte(field.Index(j).Uint()))
+ for j, l := 0, t.Len(); j < l; j++ {
+ out = append(out, uint8(field.Index(j).Uint()))
}
case reflect.Uint32:
- u32 := uint32(field.Uint())
- out = append(out, byte(u32>>24))
- out = append(out, byte(u32>>16))
- out = append(out, byte(u32>>8))
- out = append(out, byte(u32))
+ out = appendU32(out, uint32(field.Uint()))
case reflect.String:
s := field.String()
- out = append(out, byte(len(s)>>24))
- out = append(out, byte(len(s)>>16))
- out = append(out, byte(len(s)>>8))
- out = append(out, byte(len(s)))
+ out = appendInt(out, len(s))
out = append(out, s...)
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
- length := field.Len()
- if structType.Field(i).Tag.Get("ssh") != "rest" {
- out = append(out, byte(length>>24))
- out = append(out, byte(length>>16))
- out = append(out, byte(length>>8))
- out = append(out, byte(length))
+ if v.Type().Field(i).Tag.Get("ssh") != "rest" {
+ out = appendInt(out, field.Len())
}
- for j := 0; j < length; j++ {
- out = append(out, byte(field.Index(j).Uint()))
- }
+ out = append(out, field.Bytes()...)
case reflect.String:
- var length int
- for j := 0; j < field.Len(); j++ {
- if j != 0 {
- length++ /* comma */
+ offset := len(out)
+ out = appendU32(out, 0)
+ if n := field.Len(); n > 0 {
+ for j := 0; j < n; j++ {
+ f := field.Index(j)
+ if j != 0 {
+ out = append(out, ',')
+ }
+ out = append(out, f.String()...)
}
- length += len(field.Index(j).String())
- }
-
- out = append(out, byte(length>>24))
- out = append(out, byte(length>>16))
- out = append(out, byte(length>>8))
- out = append(out, byte(length))
- for j := 0; j < field.Len(); j++ {
- if j != 0 {
- out = append(out, ',')
- }
- out = append(out, field.Index(j).String()...)
+ // overwrite length value
+ binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
}
default:
panic("slice of unknown type")
@@ -382,7 +363,7 @@
if len(in) < 4 {
return
}
- length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
+ length := binary.BigEndian.Uint32(in)
if uint32(len(in)) < 4+length {
return
}
@@ -438,31 +419,18 @@
return
}
-func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
+func parseUint32(in []byte) (uint32, []byte, bool) {
if len(in) < 4 {
- return
+ return 0, nil, false
}
- out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
- rest = in[4:]
- ok = true
- return
+ return binary.BigEndian.Uint32(in), in[4:], true
}
-func parseUint64(in []byte) (out uint64, rest []byte, ok bool) {
+func parseUint64(in []byte) (uint64, []byte, bool) {
if len(in) < 8 {
- return
+ return 0, nil, false
}
- out = uint64(in[0])<<56 |
- uint64(in[1])<<48 |
- uint64(in[2])<<40 |
- uint64(in[3])<<32 |
- uint64(in[4])<<24 |
- uint64(in[5])<<16 |
- uint64(in[6])<<8 |
- uint64(in[7])
- rest = in[8:]
- ok = true
- return
+ return binary.BigEndian.Uint64(in), in[8:], true
}
func nameListLength(namelist []string) int {
@@ -502,22 +470,12 @@
}
func marshalUint32(to []byte, n uint32) []byte {
- to[0] = byte(n >> 24)
- to[1] = byte(n >> 16)
- to[2] = byte(n >> 8)
- to[3] = byte(n)
+ binary.BigEndian.PutUint32(to, n)
return to[4:]
}
func marshalUint64(to []byte, n uint64) []byte {
- to[0] = byte(n >> 56)
- to[1] = byte(n >> 48)
- to[2] = byte(n >> 40)
- to[3] = byte(n >> 32)
- to[4] = byte(n >> 24)
- to[5] = byte(n >> 16)
- to[6] = byte(n >> 8)
- to[7] = byte(n)
+ binary.BigEndian.PutUint64(to, n)
return to[8:]
}