| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package buildlet |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "encoding/json" |
| "errors" |
| "net" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "strings" |
| "testing" |
| ) |
| |
| func TestConnectSSHTLS(t *testing.T) { |
| testCases := []struct { |
| desc string |
| authUser string |
| dialer func(context.Context) (net.Conn, error) |
| key string |
| keyPair KeyPair |
| password string |
| user string |
| wantAuthUser string |
| }{ |
| { |
| desc: "tls-without-authuser", |
| authUser: "", |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "gomote", |
| }, |
| { |
| desc: "tls-with-authuser", |
| authUser: "george", |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "george", |
| }, |
| { |
| desc: "tls-with-configured-dialer", |
| authUser: "", |
| dialer: func(_ context.Context) (net.Conn, error) { return nil, errors.New("test error") }, |
| key: "key-foo", |
| keyPair: createKeyPair(t), |
| password: "foo", |
| user: "kate", |
| wantAuthUser: "gomote", |
| }, |
| } |
| for _, tc := range testCases { |
| t.Run(tc.desc, func(t *testing.T) { |
| ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { |
| t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) |
| } |
| if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { |
| t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) |
| } |
| if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); !gotOk || gotAuthUser != tc.wantAuthUser || gotAuthPass != tc.password { |
| t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, true", gotAuthUser, gotAuthPass, gotOk, tc.wantAuthUser, tc.password) |
| } |
| w.WriteHeader(http.StatusSwitchingProtocols) |
| })) |
| cert, err := tls.X509KeyPair([]byte(tc.keyPair.CertPEM), []byte(tc.keyPair.KeyPEM)) |
| if err != nil { |
| t.Fatalf("tls.X509KeyPair([]byte(%q), []byte(%q)) = %v, %q; want no error", tc.keyPair.CertPEM, tc.keyPair.KeyPEM, cert, err) |
| } |
| ts.TLS = &tls.Config{ |
| Certificates: []tls.Certificate{cert}, |
| } |
| ts.StartTLS() |
| defer ts.Close() |
| c := client{ |
| ipPort: strings.TrimPrefix(ts.URL, "https://"), |
| tls: tc.keyPair, |
| password: tc.password, |
| authUser: tc.authUser, |
| dialer: tc.dialer, |
| } |
| gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) |
| if gotErr != nil { |
| t.Fatalf("Client.ConnectSSH(%s, %s) = %v, %v; want no error", tc.user, tc.key, gotConn, gotErr) |
| } |
| }) |
| } |
| } |
| |
| func TestConnectSSHNonTLS(t *testing.T) { |
| testCases := []struct { |
| desc string |
| authUser string |
| basicAuth bool |
| dialer func(context.Context) (net.Conn, error) |
| key string |
| password string |
| user string |
| wantErr bool |
| }{ |
| { |
| desc: "non-tls-without-authuser", |
| authUser: "gomote", |
| basicAuth: false, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: false, |
| }, |
| { |
| desc: "non-tls--with-authuser", |
| authUser: "gomote", |
| basicAuth: true, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: false, |
| }, |
| { |
| desc: "non-tls-with-configured-dialer", |
| authUser: "gomote", |
| basicAuth: true, |
| dialer: func(context.Context) (net.Conn, error) { |
| return nil, errors.New("test error") |
| }, |
| key: "key-foo", |
| password: "foo", |
| user: "kate", |
| wantErr: true, |
| }, |
| } |
| for _, tc := range testCases { |
| t.Run(tc.desc, func(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { |
| t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) |
| } |
| if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { |
| t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) |
| } |
| if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); gotOk || gotAuthUser != "" || gotAuthPass != "" { |
| t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, %t", gotAuthUser, gotAuthPass, gotOk, tc.user, tc.password, tc.basicAuth) |
| } |
| w.WriteHeader(http.StatusSwitchingProtocols) |
| })) |
| defer ts.Close() |
| c := client{ |
| ipPort: strings.TrimPrefix(ts.URL, "http://"), |
| password: tc.password, |
| authUser: tc.authUser, |
| dialer: tc.dialer, |
| } |
| gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) |
| if (gotErr != nil) != tc.wantErr { |
| t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want net.Conn, error=%t", tc.user, tc.key, gotConn, gotErr, tc.wantErr) |
| } |
| }) |
| } |
| } |
| |
| func createKeyPair(t *testing.T) KeyPair { |
| kp, err := NewKeyPair() |
| if err != nil { |
| t.Fatalf("NewKeyPair() = %v, %s; want no error", kp, err) |
| } |
| return kp |
| } |
| |
| // Test that Exec returns ErrTimeout upon reaching the context timeout |
| // during command execution, as its documentation promises. |
| func TestExecTimeoutError(t *testing.T) { |
| mux := http.NewServeMux() |
| mux.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) { |
| json.NewEncoder(w).Encode(Status{}) |
| }) |
| mux.HandleFunc("/exec", func(w http.ResponseWriter, req *http.Request) { |
| w.Write([]byte(".")) |
| w.(http.Flusher).Flush() // /exec needs to flush headers right away. |
| <-req.Context().Done() // Simulate that execution hangs, so no more output. |
| }) |
| ts := httptest.NewServer(mux) |
| defer ts.Close() |
| u, err := url.Parse(ts.URL) |
| if err != nil { |
| t.Fatalf("unable to parse http server url %s", err) |
| } |
| cl := NewClient(u.Host, NoKeyPair) |
| defer cl.Close() |
| |
| // Use a custom context that reports context.DeadlineExceeded |
| // after Exec starts command execution. (context.WithTimeout |
| // requires us to select an arbitrary duration, which might |
| // not be long enough or will make the test take too long.) |
| ctx := deadlineOnDemandContext{ |
| Context: context.Background(), |
| done: make(chan struct{}), |
| } |
| _, execErr := cl.Exec(ctx, "./bin/test", ExecOpts{ |
| OnStartExec: func() { close(ctx.done) }, |
| }) |
| if execErr != ErrTimeout { |
| t.Errorf("cl.Exec error = %v; want %v", execErr, ErrTimeout) |
| } |
| } |
| |
// deadlineOnDemandContext wraps a parent Context but supplies its own Done
// channel and Err result. Err reports nil until done is closed and
// context.DeadlineExceeded afterwards. Deadline and Value are inherited from
// the embedded Context.
type deadlineOnDemandContext struct {
	context.Context
	done chan struct{} // closing this simulates the deadline passing
}

// Done returns the channel that is closed when the simulated deadline fires.
func (ctx deadlineOnDemandContext) Done() <-chan struct{} {
	return ctx.done
}

// Err reports context.DeadlineExceeded once done has been closed and nil
// before that. It never blocks.
func (ctx deadlineOnDemandContext) Err() error {
	select {
	case <-ctx.done:
		return context.DeadlineExceeded
	default:
		return nil
	}
}