| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package buildlet |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "errors" |
| "net" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| ) |
| |
| func TestConnectSSHTLS(t *testing.T) { |
| testCases := []struct { |
| desc string |
| authUser string |
| dialer func(context.Context) (net.Conn, error) |
| key string |
| keyPair KeyPair |
| password string |
| user string |
| wantAuthUser string |
| }{ |
| { |
| desc: "tls-without-authuser", |
| authUser: "", |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "gomote", |
| }, |
| { |
| desc: "tls-with-authuser", |
| authUser: "george", |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "george", |
| }, |
| { |
| desc: "tls-with-configured-dialer", |
| authUser: "", |
| dialer: func(_ context.Context) (net.Conn, error) { return nil, errors.New("test error") }, |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "gomote", |
| }, |
| } |
| for _, tc := range testCases { |
| t.Run(tc.desc, func(t *testing.T) { |
| ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { |
| t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) |
| } |
| if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { |
| t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) |
| } |
| if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); !gotOk || gotAuthUser != tc.wantAuthUser || gotAuthPass != tc.password { |
| t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, true", gotAuthUser, gotAuthPass, gotOk, tc.wantAuthUser, tc.password) |
| } |
| w.WriteHeader(http.StatusSwitchingProtocols) |
| })) |
| cert, err := tls.X509KeyPair([]byte(tc.keyPair.CertPEM), []byte(tc.keyPair.KeyPEM)) |
| if err != nil { |
| t.Fatalf("tls.X509KeyPair([]byte(%q), []byte(%q)) = %v, %q; want no error", tc.keyPair.CertPEM, tc.keyPair.KeyPEM, cert, err) |
| } |
| ts.TLS = &tls.Config{ |
| Certificates: []tls.Certificate{cert}, |
| } |
| ts.StartTLS() |
| defer ts.Close() |
| c := Client{ |
| ipPort: strings.TrimPrefix(ts.URL, "https://"), |
| tls: tc.keyPair, |
| password: tc.password, |
| authUser: tc.authUser, |
| dialer: tc.dialer, |
| } |
| gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) |
| if gotErr != nil { |
| t.Fatalf("Client.ConnectSSH(%s, %s) = %v, %v; want no error", tc.user, tc.key, gotConn, gotErr) |
| } |
| }) |
| } |
| } |
| |
| func TestConnectSSHNonTLS(t *testing.T) { |
| testCases := []struct { |
| desc string |
| authUser string |
| basicAuth bool |
| dialer func(context.Context) (net.Conn, error) |
| key string |
| password string |
| user string |
| wantErr bool |
| }{ |
| { |
| desc: "non-tls-without-authuser", |
| authUser: "gomote", |
| basicAuth: false, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: false, |
| }, |
| { |
| desc: "non-tls--with-authuser", |
| authUser: "gomote", |
| basicAuth: true, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: false, |
| }, |
| { |
| desc: "non-tls-with-configured-dialer", |
| authUser: "gomote", |
| basicAuth: true, |
| dialer: func(context.Context) (net.Conn, error) { |
| return nil, errors.New("test error") |
| }, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: true, |
| }, |
| } |
| for _, tc := range testCases { |
| t.Run(tc.desc, func(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { |
| t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) |
| } |
| if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { |
| t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) |
| } |
| if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); gotOk || gotAuthUser != "" || gotAuthPass != "" { |
| t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, %t", gotAuthUser, gotAuthPass, gotOk, tc.user, tc.password, tc.basicAuth) |
| } |
| w.WriteHeader(http.StatusSwitchingProtocols) |
| })) |
| defer ts.Close() |
| c := Client{ |
| ipPort: strings.TrimPrefix(ts.URL, "http://"), |
| password: tc.password, |
| authUser: tc.authUser, |
| dialer: tc.dialer, |
| } |
| gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) |
| if (gotErr != nil) != tc.wantErr { |
| t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want net.Conn, error=%t", tc.user, tc.key, gotConn, gotErr, tc.wantErr) |
| } |
| }) |
| } |
| } |
| |
| func createKeyPair(t *testing.T) KeyPair { |
| kp, err := NewKeyPair() |
| if err != nil { |
| t.Fatalf("NewKeyPair() = %v, %s; want no error", kp, err) |
| } |
| return kp |
| } |