blob: 3266e98c532996ff19cbac05dac396cf5f1a5860 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rendezvous
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"golang.org/x/build/revdial/v2"
)
// TestNew checks that New can build a Rendezvous from a cancellable context
// without panicking. The context is cancelled when the test finishes.
func TestNew(t *testing.T) {
	parent := context.Background()
	ctx, stop := context.WithCancel(parent)
	t.Cleanup(stop)
	_ = New(ctx)
}
// TestPurgeExpiredRegistrations seeds the registration map with one entry
// whose deadline is the Unix epoch. It then checks that a single call to
// purgeExpiredRegistrations leaves the map empty.
func TestPurgeExpiredRegistrations(t *testing.T) {
	expired := &entry{
		deadline: time.Unix(0, 0),
		ch:       make(chan *result, 1),
	}
	rdv := &Rendezvous{
		m: map[string]*entry{"test": expired},
	}
	rdv.purgeExpiredRegistrations()
	if remaining := len(rdv.m); remaining != 0 {
		t.Errorf("purgeExpiredRegistrations() did not purge expired entries: want 0 got %d", remaining)
	}
}
func TestRegisterInstance(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
rdv := New(ctx)
rdv.RegisterInstance(ctx, "sample-1", time.Minute)
if len(rdv.m) != 1 {
t.Errorf("RegisterInstance: want 1, got %d", len(rdv.m))
}
}
// TestWaitForInstanceError sends requests over TLS to HandleReverse and checks
// the HTTP status code returned in each case:
//   - a required header is missing: 400 Bad Request
//   - all headers are present but the ID was never registered: 412 Precondition Failed
func TestWaitForInstanceError(t *testing.T) {
	testCases := []struct {
		desc           string
		headers        map[string]string
		wantStatusCode int
	}{
		{desc: "missing host header", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token"}, wantStatusCode: http.StatusBadRequest},
		{desc: "missing id header", headers: map[string]string{HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: http.StatusBadRequest},
		{desc: "missing auth token", headers: map[string]string{HeaderID: "test-id", HeaderHostname: "test-hostname"}, wantStatusCode: http.StatusBadRequest},
		{desc: "missing registration", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: http.StatusPreconditionFailed},
	}
	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			rdv := &Rendezvous{
				m: make(map[string]*entry),
				// Accept every token so that only the header and registration
				// checks decide the outcome.
				validator: func(ctx context.Context, jwt string) bool {
					return true
				},
			}
			ts := httptest.NewTLSServer(http.HandlerFunc(rdv.HandleReverse))
			defer ts.Close()
			client := ts.Client()
			req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
			if err != nil {
				t.Fatalf("http.NewRequest(%q): %s", ts.URL, err)
			}
			for k, v := range tc.headers {
				req.Header.Set(k, v)
			}
			resp, err := client.Do(req)
			if err != nil {
				// resp is nil when Do fails, so stop here rather than
				// dereferencing it below.
				t.Fatalf("client.Do(%s): %s", ts.URL, err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != tc.wantStatusCode {
				t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, tc.wantStatusCode)
			}
		})
	}
}
// TestWaitForInstaceErrorNonTLS sends a header-less request to HandleReverse
// over plain HTTP (no TLS) and expects 500 Internal Server Error. In the TLS
// cases, missing headers produce 400 instead.
func TestWaitForInstaceErrorNonTLS(t *testing.T) {
	rdv := &Rendezvous{
		m: make(map[string]*entry),
		// Accept every token so that token validation cannot cause the failure.
		validator: func(ctx context.Context, jwt string) bool {
			return true
		},
	}
	ts := httptest.NewServer(http.HandlerFunc(rdv.HandleReverse))
	defer ts.Close()
	client := ts.Client()
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatalf("http.NewRequest(%q): %s", ts.URL, err)
	}
	resp, err := client.Do(req)
	if err != nil {
		// resp is nil when Do fails, so stop here rather than dereferencing it.
		t.Fatalf("client.Do(%s): %s", ts.URL, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusInternalServerError {
		t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, http.StatusInternalServerError)
	}
}
// TestWaitForInstaceRevdialError registers an instance, then sends a TLS
// request for it to /reverse in the background. WaitForInstance is expected to
// return an error because the test server does not serve the instance's status
// endpoint.
func TestWaitForInstaceRevdialError(t *testing.T) {
	rdv := &Rendezvous{
		m: make(map[string]*entry),
		// Accept every token so that token validation does not cause the failure.
		validator: func(ctx context.Context, jwt string) bool {
			return true
		},
	}
	instanceID := "test-id-3"
	ctx := context.Background()
	rdv.RegisterInstance(ctx, instanceID, 15*time.Second)
	mux := http.NewServeMux()
	mux.HandleFunc("/reverse", rdv.HandleReverse)
	mux.Handle("/revdial", revdial.ConnHandler())
	ts := httptest.NewTLSServer(mux)
	defer ts.Close()
	client := ts.Client()
	reverseURL := ts.URL + "/reverse"
	req, err := http.NewRequest(http.MethodGet, reverseURL, nil)
	if err != nil {
		t.Fatalf("http.NewRequest(%q): %s", reverseURL, err)
	}
	req.Header.Set(HeaderID, instanceID)
	req.Header.Set(HeaderToken, "test-token")
	req.Header.Set(HeaderHostname, "test-hostname")
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// The outcome of the reverse request is not what this test checks.
		// Its error is ignored on purpose; only a returned body is closed so
		// the connection is released.
		resp, err := client.Do(req)
		if err == nil {
			resp.Body.Close()
		}
	}()
	_, err = rdv.WaitForInstance(ctx, instanceID)
	if err == nil {
		// expect a missing status endpoint
		t.Fatal("WaitForInstance(): got nil, want error")
	}
	wg.Wait()
}
// TestDeregisterInstance stores a single entry under a known ID and checks
// that DeregisterInstance removes it from the registration map.
func TestDeregisterInstance(t *testing.T) {
	rdv := &Rendezvous{
		m: make(map[string]*entry),
	}
	id := "test-xyz"
	rdv.m[id] = &entry{}
	rdv.DeregisterInstance(context.Background(), id)
	if len(rdv.m) != 0 {
		t.Errorf("DeregisterInstance() did not remove the entry: want 0 got %d", len(rdv.m))
	}
}