blob: 9315bb638674963996fe9ae2c6c643ac6f68f4e2 [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package buildlet
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// TestConnectSSHTLS checks that, for a TLS-enabled client, ConnectSSH sends
// the X-Go-Ssh-User and X-Go-Authorized-Key headers together with HTTP basic
// auth. When the client has no authUser, the basic-auth user is expected to
// be "gomote". The configured-dialer case expects success even though its
// dialer always fails, which implies the TLS path does not use the custom
// dialer.
func TestConnectSSHTLS(t *testing.T) {
	testCases := []struct {
		desc         string
		authUser     string
		dialer       func(context.Context) (net.Conn, error)
		key          string
		keyPair      KeyPair
		password     string
		user         string
		wantAuthUser string
	}{
		{
			desc:         "tls-without-authuser",
			authUser:     "",
			key:          "key-foo",
			keyPair:      createKeyPair(t),
			password:     "foo",
			user:         "kate",
			wantAuthUser: "gomote",
		},
		{
			desc:         "tls-with-authuser",
			authUser:     "george",
			key:          "key-foo",
			keyPair:      createKeyPair(t),
			password:     "foo",
			user:         "kate",
			wantAuthUser: "george",
		},
		{
			desc:         "tls-with-configured-dialer",
			authUser:     "",
			dialer:       func(context.Context) (net.Conn, error) { return nil, errors.New("test error") },
			key:          "key-foo",
			keyPair:      createKeyPair(t),
			password:     "foo",
			user:         "kate",
			wantAuthUser: "gomote",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user {
					t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user)
				}
				if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key {
					t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key)
				}
				if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); !gotOk || gotAuthUser != tc.wantAuthUser || gotAuthPass != tc.password {
					t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, true", gotAuthUser, gotAuthPass, gotOk, tc.wantAuthUser, tc.password)
				}
				w.WriteHeader(http.StatusSwitchingProtocols)
			}))
			cert, err := tls.X509KeyPair([]byte(tc.keyPair.CertPEM), []byte(tc.keyPair.KeyPEM))
			if err != nil {
				// Don't echo the PEM blocks: KeyPEM is private key material.
				t.Fatalf("tls.X509KeyPair(CertPEM, KeyPEM) = %v; want no error", err)
			}
			// Serve with the same key pair the client is configured with.
			ts.TLS = &tls.Config{
				Certificates: []tls.Certificate{cert},
			}
			ts.StartTLS()
			defer ts.Close()
			c := client{
				ipPort:   strings.TrimPrefix(ts.URL, "https://"),
				tls:      tc.keyPair,
				password: tc.password,
				authUser: tc.authUser,
				dialer:   tc.dialer,
			}
			gotConn, gotErr := c.ConnectSSH(tc.user, tc.key)
			if gotErr != nil {
				t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want no error", tc.user, tc.key, gotConn, gotErr)
			}
			// Release the connection so it does not outlive the subtest.
			if gotConn != nil {
				gotConn.Close()
			}
		})
	}
}
// TestConnectSSHNonTLS checks that, for a client without TLS, ConnectSSH
// sends the X-Go-Ssh-User and X-Go-Authorized-Key headers but never HTTP
// basic auth, whether or not authUser is set. It also checks that a failing
// configured dialer surfaces as an error from ConnectSSH.
func TestConnectSSHNonTLS(t *testing.T) {
	testCases := []struct {
		desc     string
		authUser string
		dialer   func(context.Context) (net.Conn, error)
		key      string
		password string
		user     string
		wantErr  bool
	}{
		{
			desc:     "non-tls-without-authuser",
			authUser: "",
			key:      "key-foo",
			password: "foo",
			user:     "kate",
			wantErr:  false,
		},
		{
			desc:     "non-tls-with-authuser",
			authUser: "gomote",
			key:      "key-foo",
			password: "foo",
			user:     "kate",
			wantErr:  false,
		},
		{
			desc:     "non-tls-with-configured-dialer",
			authUser: "gomote",
			dialer: func(context.Context) (net.Conn, error) {
				return nil, errors.New("test error")
			},
			key:      "key-foo",
			password: "foo",
			user:     "kate",
			wantErr:  true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user {
					t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user)
				}
				if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key {
					t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key)
				}
				// A plain-HTTP request must carry no basic-auth credentials.
				if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); gotOk || gotAuthUser != "" || gotAuthPass != "" {
					t.Errorf("Request.BasicAuth() = %q, %q, %t; want \"\", \"\", false", gotAuthUser, gotAuthPass, gotOk)
				}
				w.WriteHeader(http.StatusSwitchingProtocols)
			}))
			defer ts.Close()
			c := client{
				ipPort:   strings.TrimPrefix(ts.URL, "http://"),
				password: tc.password,
				authUser: tc.authUser,
				dialer:   tc.dialer,
			}
			gotConn, gotErr := c.ConnectSSH(tc.user, tc.key)
			if (gotErr != nil) != tc.wantErr {
				t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want net.Conn, error=%t", tc.user, tc.key, gotConn, gotErr, tc.wantErr)
			}
			// Release the connection so it does not outlive the subtest.
			if gotConn != nil {
				gotConn.Close()
			}
		})
	}
}
// createKeyPair returns a KeyPair generated by NewKeyPair and fails the test
// immediately if generation returns an error. It is marked as a test helper,
// so a failure is reported at the caller's line.
func createKeyPair(t *testing.T) KeyPair {
	t.Helper()
	kp, err := NewKeyPair()
	if err != nil {
		t.Fatalf("NewKeyPair() = %v, %s; want no error", kp, err)
	}
	return kp
}
// TestExecTimeoutError checks the documented promise that Exec returns
// ErrTimeout when the context reaches its deadline while the command is
// still executing.
func TestExecTimeoutError(t *testing.T) {
	routes := http.NewServeMux()
	routes.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) {
		json.NewEncoder(w).Encode(Status{})
	})
	routes.HandleFunc("/exec", func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("."))
		// /exec must flush its headers right away.
		w.(http.Flusher).Flush()
		// Pretend the command hangs: produce no further output until the
		// request is finished.
		<-req.Context().Done()
	})
	srv := httptest.NewServer(routes)
	defer srv.Close()

	srvURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("unable to parse http server url %s", err)
	}
	cl := NewClient(srvURL.Host, NoKeyPair)
	defer cl.Close()

	// The deadline is triggered by hand from OnStartExec rather than via
	// context.WithTimeout, which would force picking an arbitrary duration
	// that is either too short to be reliable or needlessly slow.
	started := make(chan struct{})
	ctx := deadlineOnDemandContext{Context: context.Background(), done: started}
	opts := ExecOpts{OnStartExec: func() { close(started) }}

	_, execErr := cl.Exec(ctx, "./bin/test", opts)
	if execErr != ErrTimeout {
		t.Errorf("cl.Exec error = %v; want %v", execErr, ErrTimeout)
	}
}
// deadlineOnDemandContext is a context.Context whose deadline "expires" only
// when done is closed. Until then Err reports nil; afterwards it reports
// context.DeadlineExceeded. All methods it does not override (Deadline,
// Value) come from the embedded Context.
type deadlineOnDemandContext struct {
	context.Context
	// done is closed by the test to trigger the simulated deadline.
	done chan struct{}
}

// Done returns the test-controlled channel in place of the embedded
// context's Done channel.
func (c deadlineOnDemandContext) Done() <-chan struct{} {
	return c.done
}

// Err returns context.DeadlineExceeded once done has been closed, and nil
// before that. It never blocks.
func (c deadlineOnDemandContext) Err() error {
	select {
	case <-c.done:
		return context.DeadlineExceeded
	default:
	}
	return nil
}