go.crypto/ssh: prevent concurrent reads and concurrent writes over the same agent connection
minor fix for v01 cert parsing when algo is not supported
R=golang-dev, agl, dave
CC=golang-dev
https://golang.org/cl/6116052
diff --git a/ssh/agent.go b/ssh/agent.go
index 3c701b7..28d52df 100644
--- a/ssh/agent.go
+++ b/ssh/agent.go
@@ -11,6 +11,7 @@
"encoding/base64"
"errors"
"io"
+ "sync"
)
// See PROTOCOL.agent, section 3.
@@ -126,24 +127,37 @@
}
// AgentClient provides a means to communicate with an ssh agent process based
-// on the protocol described in PROTOCOL.agent?rev=1.6. It contains an
-// embedded io.ReadWriter that is typically represented by using a *net.UnixConn.
+// on the protocol described in PROTOCOL.agent?rev=1.6.
type AgentClient struct {
- io.ReadWriter
+ // conn is typically represented by using a *net.UnixConn
+ conn io.ReadWriter
+ // mu is used to prevent concurrent access to the agent
+ mu sync.Mutex
+}
+
+// NewAgentClient creates and returns a new *AgentClient using the
+// passed in io.ReadWriter as a connection to a ssh agent.
+func NewAgentClient(rw io.ReadWriter) *AgentClient {
+ return &AgentClient{conn: rw}
}
// sendAndReceive sends req to the agent and waits for a reply. On success,
// the reply is unmarshaled into reply and replyType is set to the first byte of
// the reply, which contains the type of the message.
func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
+ // ac.mu prevents multiple, concurrent requests. Since the agent is typically
+ // on the same machine, we don't attempt to pipeline the requests.
+ ac.mu.Lock()
+ defer ac.mu.Unlock()
+
msg := make([]byte, stringLength(len(req)))
marshalString(msg, req)
- if _, err = ac.Write(msg); err != nil {
+ if _, err = ac.conn.Write(msg); err != nil {
return
}
var respSizeBuf [4]byte
- if _, err = io.ReadFull(ac, respSizeBuf[:]); err != nil {
+ if _, err = io.ReadFull(ac.conn, respSizeBuf[:]); err != nil {
return
}
respSize, _, _ := parseUint32(respSizeBuf[:])
@@ -154,7 +168,7 @@
}
buf := make([]byte, respSize)
- if _, err = io.ReadFull(ac, buf); err != nil {
+ if _, err = io.ReadFull(ac.conn, buf); err != nil {
return
}
return unmarshalAgentMsg(buf)
diff --git a/ssh/certs.go b/ssh/certs.go
index 107fd1a..40cf706 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -78,6 +78,7 @@
}
cert.Key = dsaPubKey
default:
+ ok = false
return
}