ssh/agent: parse constraints when adding keys

Change-Id: I264fc3e3e441d6e5ff7c5aa624eee1018cf9e4de
Reviewed-on: https://go-review.googlesource.com/50811
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
index ecfd7c5..dce7682 100644
--- a/ssh/agent/client.go
+++ b/ssh/agent/client.go
@@ -57,6 +57,17 @@
 	Signers() ([]ssh.Signer, error)
 }
 
+// ConstraintExtension describes an optional constraint defined by users.
+type ConstraintExtension struct {
+	// ExtensionName consist of a UTF-8 string suffixed by the
+	// implementation domain following the naming scheme defined
+	// in Section 4.2 of [RFC4251], e.g.  "foo@example.com".
+	ExtensionName string
+	// ExtensionDetails contains the actual content of the extended
+	// constraint.
+	ExtensionDetails []byte
+}
+
 // AddedKey describes an SSH key to be added to an Agent.
 type AddedKey struct {
 	// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
@@ -73,6 +84,9 @@
 	// ConfirmBeforeUse, if true, requests that the agent confirm with the
 	// user before each use of this key.
 	ConfirmBeforeUse bool
+	// ConstraintExtensions are the experimental or private-use constraints
+	// defined by users.
+	ConstraintExtensions []ConstraintExtension
 }
 
 // See [PROTOCOL.agent], section 3.
@@ -94,8 +108,9 @@
 	agentAddSmartcardKeyConstrained = 26
 
 	// 3.7 Key constraint identifiers
-	agentConstrainLifetime = 1
-	agentConstrainConfirm  = 2
+	agentConstrainLifetime  = 1
+	agentConstrainConfirm   = 2
+	agentConstrainExtension = 3
 )
 
 // maxAgentResponseBytes is the maximum agent reply size that is accepted. This
@@ -151,6 +166,19 @@
 	Rest   []byte `ssh:"rest"`
 }
 
+// 3.7 Key constraint identifiers
+type constrainLifetimeAgentMsg struct {
+	LifetimeSecs uint32 `sshtype:"1"`
+}
+
+type constrainExtensionAgentMsg struct {
+	ExtensionName    string `sshtype:"3"`
+	ExtensionDetails []byte
+
+	// Rest is a field used for parsing, not part of message
+	Rest []byte `ssh:"rest"`
+}
+
 // Key represents a protocol 2 public key as defined in
 // [PROTOCOL.agent], section 2.5.2.
 type Key struct {
@@ -542,11 +570,7 @@
 	var constraints []byte
 
 	if secs := key.LifetimeSecs; secs != 0 {
-		constraints = append(constraints, agentConstrainLifetime)
-
-		var secsBytes [4]byte
-		binary.BigEndian.PutUint32(secsBytes[:], secs)
-		constraints = append(constraints, secsBytes[:]...)
+		constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
 	}
 
 	if key.ConfirmBeforeUse {
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
index 68a333f..793dd2e 100644
--- a/ssh/agent/server.go
+++ b/ssh/agent/server.go
@@ -155,6 +155,44 @@
 	return nil, fmt.Errorf("unknown opcode %d", data[0])
 }
 
+func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
+	for len(constraints) != 0 {
+		switch constraints[0] {
+		case agentConstrainLifetime:
+			lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
+			constraints = constraints[5:]
+		case agentConstrainConfirm:
+			confirmBeforeUse = true
+			constraints = constraints[1:]
+		case agentConstrainExtension:
+			var msg constrainExtensionAgentMsg
+			if err = ssh.Unmarshal(constraints, &msg); err != nil {
+				return 0, false, nil, err
+			}
+			extensions = append(extensions, ConstraintExtension{
+				ExtensionName:    msg.ExtensionName,
+				ExtensionDetails: msg.ExtensionDetails,
+			})
+			constraints = msg.Rest
+		default:
+			return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
+		}
+	}
+	return
+}
+
+func setConstraints(key *AddedKey, constraintBytes []byte) error {
+	lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
+	if err != nil {
+		return err
+	}
+
+	key.LifetimeSecs = lifetimeSecs
+	key.ConfirmBeforeUse = confirmBeforeUse
+	key.ConstraintExtensions = constraintExtensions
+	return nil
+}
+
 func parseRSAKey(req []byte) (*AddedKey, error) {
 	var k rsaKeyMsg
 	if err := ssh.Unmarshal(req, &k); err != nil {
@@ -173,7 +211,11 @@
 	}
 	priv.Precompute()
 
-	return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseEd25519Key(req []byte) (*AddedKey, error) {
@@ -182,7 +224,12 @@
 		return nil, err
 	}
 	priv := ed25519.PrivateKey(k.Priv)
-	return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil
+
+	addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseDSAKey(req []byte) (*AddedKey, error) {
@@ -202,7 +249,11 @@
 		X: k.X,
 	}
 
-	return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
@@ -243,7 +294,12 @@
 	if !ok {
 		return nil, errors.New("agent: bad ED25519 certificate")
 	}
-	return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
+
+	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseECDSAKey(req []byte) (*AddedKey, error) {
@@ -257,7 +313,11 @@
 		return nil, err
 	}
 
-	return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseRSACert(req []byte) (*AddedKey, error) {
@@ -300,7 +360,11 @@
 	}
 	priv.Precompute()
 
-	return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseDSACert(req []byte) (*AddedKey, error) {
@@ -338,7 +402,11 @@
 		X: k.X,
 	}
 
-	return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func parseECDSACert(req []byte) (*AddedKey, error) {
@@ -371,7 +439,11 @@
 		return nil, err
 	}
 
-	return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
+	if err := setConstraints(addedKey, k.Constraints); err != nil {
+		return nil, err
+	}
+	return addedKey, nil
 }
 
 func (s *server) insertIdentity(req []byte) error {
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
index 6b0837d..c34d05b 100644
--- a/ssh/agent/server_test.go
+++ b/ssh/agent/server_test.go
@@ -8,6 +8,9 @@
 	"crypto"
 	"crypto/rand"
 	"fmt"
+	pseudorand "math/rand"
+	"reflect"
+	"strings"
 	"testing"
 
 	"golang.org/x/crypto/ssh"
@@ -207,3 +210,50 @@
 		}
 	}
 }
+
+func TestParseConstraints(t *testing.T) {
+	// Test LifetimeSecs
+	var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
+	lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
+	if err != nil {
+		t.Fatalf("parseConstraints: %v", err)
+	}
+	if lifetimeSecs != msg.LifetimeSecs {
+		t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
+	}
+
+	// Test ConfirmBeforeUse
+	_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
+	if err != nil {
+		t.Fatalf("%v", err)
+	}
+	if !confirmBeforeUse {
+		t.Error("got comfirmBeforeUse == false")
+	}
+
+	// Test ConstraintExtensions
+	var data []byte
+	var expect []ConstraintExtension
+	for i := 0; i < 10; i++ {
+		var ext = ConstraintExtension{
+			ExtensionName:    fmt.Sprintf("name%d", i),
+			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
+		}
+		expect = append(expect, ext)
+		data = append(data, agentConstrainExtension)
+		data = append(data, ssh.Marshal(ext)...)
+	}
+	_, _, extensions, err := parseConstraints(data)
+	if err != nil {
+		t.Fatalf("%v", err)
+	}
+	if !reflect.DeepEqual(expect, extensions) {
+		t.Errorf("got extension %v, want %v", extensions, expect)
+	}
+
+	// Test Unknown Constraint
+	_, _, _, err = parseConstraints([]byte{128})
+	if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
+		t.Errorf("unexpected error: %v", err)
+	}
+}