ssh/agent: parse constraints when adding keys
Change-Id: I264fc3e3e441d6e5ff7c5aa624eee1018cf9e4de
Reviewed-on: https://go-review.googlesource.com/50811
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
index ecfd7c5..dce7682 100644
--- a/ssh/agent/client.go
+++ b/ssh/agent/client.go
@@ -57,6 +57,17 @@
Signers() ([]ssh.Signer, error)
}
+// ConstraintExtension describes an optional constraint defined by users.
+type ConstraintExtension struct {
+ // ExtensionName consist of a UTF-8 string suffixed by the
+ // implementation domain following the naming scheme defined
+ // in Section 4.2 of [RFC4251], e.g. "foo@example.com".
+ ExtensionName string
+ // ExtensionDetails contains the actual content of the extended
+ // constraint.
+ ExtensionDetails []byte
+}
+
// AddedKey describes an SSH key to be added to an Agent.
type AddedKey struct {
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
@@ -73,6 +84,9 @@
// ConfirmBeforeUse, if true, requests that the agent confirm with the
// user before each use of this key.
ConfirmBeforeUse bool
+ // ConstraintExtensions are the experimental or private-use constraints
+ // defined by users.
+ ConstraintExtensions []ConstraintExtension
}
// See [PROTOCOL.agent], section 3.
@@ -94,8 +108,9 @@
agentAddSmartcardKeyConstrained = 26
// 3.7 Key constraint identifiers
- agentConstrainLifetime = 1
- agentConstrainConfirm = 2
+ agentConstrainLifetime = 1
+ agentConstrainConfirm = 2
+ agentConstrainExtension = 3
)
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
@@ -151,6 +166,19 @@
Rest []byte `ssh:"rest"`
}
+// 3.7 Key constraint identifiers
+type constrainLifetimeAgentMsg struct {
+ LifetimeSecs uint32 `sshtype:"1"`
+}
+
+type constrainExtensionAgentMsg struct {
+ ExtensionName string `sshtype:"3"`
+ ExtensionDetails []byte
+
+ // Rest is a field used for parsing, not part of message
+ Rest []byte `ssh:"rest"`
+}
+
// Key represents a protocol 2 public key as defined in
// [PROTOCOL.agent], section 2.5.2.
type Key struct {
@@ -542,11 +570,7 @@
var constraints []byte
if secs := key.LifetimeSecs; secs != 0 {
- constraints = append(constraints, agentConstrainLifetime)
-
- var secsBytes [4]byte
- binary.BigEndian.PutUint32(secsBytes[:], secs)
- constraints = append(constraints, secsBytes[:]...)
+ constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
}
if key.ConfirmBeforeUse {
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
index 68a333f..793dd2e 100644
--- a/ssh/agent/server.go
+++ b/ssh/agent/server.go
@@ -155,6 +155,44 @@
return nil, fmt.Errorf("unknown opcode %d", data[0])
}
+func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
+ for len(constraints) != 0 {
+ switch constraints[0] {
+ case agentConstrainLifetime:
+ lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
+ constraints = constraints[5:]
+ case agentConstrainConfirm:
+ confirmBeforeUse = true
+ constraints = constraints[1:]
+ case agentConstrainExtension:
+ var msg constrainExtensionAgentMsg
+ if err = ssh.Unmarshal(constraints, &msg); err != nil {
+ return 0, false, nil, err
+ }
+ extensions = append(extensions, ConstraintExtension{
+ ExtensionName: msg.ExtensionName,
+ ExtensionDetails: msg.ExtensionDetails,
+ })
+ constraints = msg.Rest
+ default:
+ return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
+ }
+ }
+ return
+}
+
+func setConstraints(key *AddedKey, constraintBytes []byte) error {
+ lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
+ if err != nil {
+ return err
+ }
+
+ key.LifetimeSecs = lifetimeSecs
+ key.ConfirmBeforeUse = confirmBeforeUse
+ key.ConstraintExtensions = constraintExtensions
+ return nil
+}
+
func parseRSAKey(req []byte) (*AddedKey, error) {
var k rsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
@@ -173,7 +211,11 @@
}
priv.Precompute()
- return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseEd25519Key(req []byte) (*AddedKey, error) {
@@ -182,7 +224,12 @@
return nil, err
}
priv := ed25519.PrivateKey(k.Priv)
- return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil
+
+ addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseDSAKey(req []byte) (*AddedKey, error) {
@@ -202,7 +249,11 @@
X: k.X,
}
- return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
@@ -243,7 +294,12 @@
if !ok {
return nil, errors.New("agent: bad ED25519 certificate")
}
- return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
+
+ addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseECDSAKey(req []byte) (*AddedKey, error) {
@@ -257,7 +313,11 @@
return nil, err
}
- return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseRSACert(req []byte) (*AddedKey, error) {
@@ -300,7 +360,11 @@
}
priv.Precompute()
- return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseDSACert(req []byte) (*AddedKey, error) {
@@ -338,7 +402,11 @@
X: k.X,
}
- return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func parseECDSACert(req []byte) (*AddedKey, error) {
@@ -371,7 +439,11 @@
return nil, err
}
- return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+ addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+ if err := setConstraints(addedKey, k.Constraints); err != nil {
+ return nil, err
+ }
+ return addedKey, nil
}
func (s *server) insertIdentity(req []byte) error {
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
index 6b0837d..c34d05b 100644
--- a/ssh/agent/server_test.go
+++ b/ssh/agent/server_test.go
@@ -8,6 +8,9 @@
"crypto"
"crypto/rand"
"fmt"
+ pseudorand "math/rand"
+ "reflect"
+ "strings"
"testing"
"golang.org/x/crypto/ssh"
@@ -207,3 +210,50 @@
}
}
}
+
+func TestParseConstraints(t *testing.T) {
+ // Test LifetimeSecs
+ var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
+ lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
+ if err != nil {
+ t.Fatalf("parseConstraints: %v", err)
+ }
+ if lifetimeSecs != msg.LifetimeSecs {
+ t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
+ }
+
+ // Test ConfirmBeforeUse
+ _, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
+ if err != nil {
+ t.Fatalf("%v", err)
+ }
+ if !confirmBeforeUse {
+ t.Error("got comfirmBeforeUse == false")
+ }
+
+ // Test ConstraintExtensions
+ var data []byte
+ var expect []ConstraintExtension
+ for i := 0; i < 10; i++ {
+ var ext = ConstraintExtension{
+ ExtensionName: fmt.Sprintf("name%d", i),
+ ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
+ }
+ expect = append(expect, ext)
+ data = append(data, agentConstrainExtension)
+ data = append(data, ssh.Marshal(ext)...)
+ }
+ _, _, extensions, err := parseConstraints(data)
+ if err != nil {
+ t.Fatalf("%v", err)
+ }
+ if !reflect.DeepEqual(expect, extensions) {
+ t.Errorf("got extension %v, want %v", extensions, expect)
+ }
+
+ // Test Unknown Constraint
+ _, _, _, err = parseConstraints([]byte{128})
+ if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
+ t.Errorf("unexpected error: %v", err)
+ }
+}