ssh: allow to configure public key auth algorithms on the server side

Fixes golang/go#61244

Change-Id: I29b43e379cf0cdb07b0d6935666491b997157e73
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/510775
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Commit-Queue: Nicola Murino <nicola.murino@gmail.com>
Run-TryBot: Nicola Murino <nicola.murino@gmail.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/common.go b/ssh/common.go
index b419c76..dd2ab0d 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -10,7 +10,6 @@
 	"fmt"
 	"io"
 	"math"
-	"strings"
 	"sync"
 
 	_ "crypto/sha1"
@@ -140,8 +139,6 @@
 	KeyAlgoDSA,
 }
 
-var supportedPubKeyAuthAlgosList = strings.Join(supportedPubKeyAuthAlgos, ",")
-
 // unexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
 func unexpectedMessageError(expected, got uint8) error {
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 70a7369..49bbba7 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -11,6 +11,7 @@
 	"io"
 	"log"
 	"net"
+	"strings"
 	"sync"
 )
 
@@ -50,6 +51,10 @@
 	// connection.
 	hostKeys []Signer
 
+	// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
+	// it contains the supported client public key authentication algorithms.
+	publicKeyAuthAlgorithms []string
+
 	// hostKeyAlgorithms is non-empty if we are the client. In that case,
 	// we accept these key types from the server as host key.
 	hostKeyAlgorithms []string
@@ -141,6 +146,7 @@
 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
 	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 	t.hostKeys = config.hostKeys
+	t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
 	go t.readLoop()
 	go t.kexLoop()
 	return t
@@ -649,6 +655,7 @@
 	// message with the server-sig-algs extension if the client supports it. See
 	// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
 	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
+		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
 		extInfo := &extInfoMsg{
 			NumExtensions: 2,
 			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
diff --git a/ssh/server.go b/ssh/server.go
index 727c71b..8f1505a 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -64,6 +64,13 @@
 	// Config contains configuration shared between client and server.
 	Config
 
+	// PublicKeyAuthAlgorithms specifies the supported client public key
+	// authentication algorithms. Note that this should not include certificate
+	// types since those use the underlying algorithm. This list is sent to the
+	// client if it supports the server-sig-algs extension. Order is irrelevant.
+	// If unspecified then a default set of algorithms is used.
+	PublicKeyAuthAlgorithms []string
+
 	hostKeys []Signer
 
 	// NoClientAuth is true if clients are allowed to connect without
@@ -201,6 +208,15 @@
 	if fullConf.MaxAuthTries == 0 {
 		fullConf.MaxAuthTries = 6
 	}
+	if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
+		fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
+	} else {
+		for _, algo := range fullConf.PublicKeyAuthAlgorithms {
+			if !contains(supportedPubKeyAuthAlgos, algo) {
+				return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
+			}
+		}
+	}
 	// Check if the config contains any unsupported key exchanges
 	for _, kex := range fullConf.KeyExchanges {
 		if _, ok := serverForbiddenKexAlgos[kex]; ok {
@@ -524,7 +540,7 @@
 				return nil, parseError(msgUserAuthRequest)
 			}
 			algo := string(algoBytes)
-			if !contains(supportedPubKeyAuthAlgos, underlyingAlgo(algo)) {
+			if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
 				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
 				break
 			}
@@ -591,7 +607,7 @@
 				// algorithm name that corresponds to algo with
 				// sig.Format.  This is usually the same, but
 				// for certs, the names differ.
-				if !contains(supportedPubKeyAuthAlgos, sig.Format) {
+				if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
 					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
 					break
 				}
diff --git a/ssh/server_test.go b/ssh/server_test.go
new file mode 100644
index 0000000..2145dce
--- /dev/null
+++ b/ssh/server_test.go
@@ -0,0 +1,85 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"testing"
+)
+
+func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
+	for _, tt := range []struct {
+		name      string
+		key       Signer
+		wantError bool
+	}{
+		{"rsa", testSigners["rsa"], false},
+		{"dsa", testSigners["dsa"], true},
+		{"ed25519", testSigners["ed25519"], true},
+	} {
+		c1, c2, err := netPipe()
+		if err != nil {
+			t.Fatalf("netPipe: %v", err)
+		}
+		defer c1.Close()
+		defer c2.Close()
+		serverConf := &ServerConfig{
+			PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+				return nil, nil
+			},
+		}
+		serverConf.AddHostKey(testSigners["ecdsap256"])
+
+		done := make(chan struct{})
+		go func() {
+			defer close(done)
+			NewServerConn(c1, serverConf)
+		}()
+
+		clientConf := ClientConfig{
+			User: "user",
+			Auth: []AuthMethod{
+				PublicKeys(tt.key),
+			},
+			HostKeyCallback: InsecureIgnoreHostKey(),
+		}
+
+		_, _, _, err = NewClientConn(c2, "", &clientConf)
+		if err != nil {
+			if !tt.wantError {
+				t.Errorf("%s: got unexpected error %q", tt.name, err.Error())
+			}
+		} else if tt.wantError {
+			t.Errorf("%s: succeeded, but want error", tt.name)
+		}
+		<-done
+	}
+}
+
+func TestNewServerConnValidationErrors(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	serverConf := &ServerConfig{
+		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
+	}
+	_, _, _, err = NewServerConn(c1, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with invalid public key auth algorithms succeeded")
+	}
+	serverConf = &ServerConfig{
+		Config: Config{
+			KeyExchanges: []string{kexAlgoDHGEXSHA256},
+		},
+	}
+	_, _, _, err = NewServerConn(c1, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with unsupported key exchange succeeded")
+	}
+}