kubernetes: cleanup, and add DialServicePort

Also moves the Dial methods onto *kubernetes.Client, replacing the
separate Dialer type.

Updates golang/go#18817

Change-Id: I045ac48441b9139cb0e01ceb4969b29cf5e72507
Reviewed-on: https://go-review.googlesource.com/36692
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/kubernetes/client.go b/kubernetes/client.go
index 80604d5..97d5cdd 100644
--- a/kubernetes/client.go
+++ b/kubernetes/client.go
@@ -10,10 +10,12 @@
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"log"
 	"net/http"
 	"net/url"
+	"os"
 	"strings"
 	"time"
 
@@ -22,18 +24,15 @@
 	"golang.org/x/net/context/ctxhttp"
 )
 
-const (
-	// APIEndpoint defines the base path for kubernetes API resources.
-	APIEndpoint     = "/api/v1"
-	defaultPod      = "/namespaces/default/pods"
-	defaultWatchPod = "/watch/namespaces/default/pods"
-	nodes           = "/nodes"
-)
-
 // Client is a client for the Kubernetes master.
 type Client struct {
+	httpClient *http.Client
+
+	// endpointURL is the Kubernetes master URL ending in
+	// "/api/v1".
 	endpointURL string
-	httpClient  *http.Client
+
+	namespace string // always in URL path-escaped form (for now)
 }
 
 // NewClient returns a new Kubernetes client.
@@ -46,8 +45,9 @@
 		return nil, fmt.Errorf("failed to parse URL %q: %v", baseURL, err)
 	}
 	return &Client{
-		endpointURL: strings.TrimSuffix(validURL.String(), "/") + APIEndpoint,
+		endpointURL: strings.TrimSuffix(validURL.String(), "/") + "/api/v1",
 		httpClient:  client,
+		namespace:   "default",
 	}, nil
 }
 
@@ -59,6 +59,12 @@
 	return nil
 }
 
+// nsEndpoint returns the API endpoint root for this client.
+// (This has nothing to do with Service Endpoints.)
+func (c *Client) nsEndpoint() string {
+	return c.endpointURL + "/namespaces/" + c.namespace + "/"
+}
+
 // RunLongLivedPod creates a new pod resource in the default pod namespace with
 // the given pod API specification. It assumes the pod runs a
 // long-lived server (i.e. if the container exit quickly quickly, even
@@ -72,7 +78,7 @@
 	if err := json.NewEncoder(&podJSON).Encode(pod); err != nil {
 		return nil, fmt.Errorf("failed to encode pod in json: %v", err)
 	}
-	postURL := c.endpointURL + defaultPod
+	postURL := c.nsEndpoint() + "pods"
 	req, err := http.NewRequest("POST", postURL, &podJSON)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: POST %q : %v", postURL, err)
@@ -122,39 +128,90 @@
 	}
 }
 
-// GetPods returns all pods in the cluster, regardless of status.
-func (c *Client) GetPods(ctx context.Context) ([]api.Pod, error) {
-	getURL := c.endpointURL + defaultPod
-
-	// Make request to Kubernetes API
-	req, err := http.NewRequest("GET", getURL, nil)
+func (c *Client) do(ctx context.Context, method, urlStr string, dst interface{}) error {
+	req, err := http.NewRequest(method, urlStr, nil)
 	if err != nil {
-		return nil, fmt.Errorf("failed to create request: GET %q : %v", getURL, err)
+		return err
 	}
 	res, err := ctxhttp.Do(ctx, c.httpClient, req)
 	if err != nil {
-		return nil, fmt.Errorf("failed to make request: GET %q: %v", getURL, err)
+		return err
 	}
-
-	body, err := ioutil.ReadAll(res.Body)
-	res.Body.Close()
-	if err != nil {
-		return nil, fmt.Errorf("failed to read request body for GET %q: %v", getURL, err)
-	}
+	defer res.Body.Close()
 	if res.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("http error %d GET %q: %q: %v", res.StatusCode, getURL, string(body), err)
+		body, _ := ioutil.ReadAll(res.Body)
+		return fmt.Errorf("%v %s: %v, %s", method, urlStr, res.Status, body)
 	}
+	if dst != nil {
+		var r io.Reader = res.Body
+		if false && strings.Contains(urlStr, "endpoints") { // for debugging
+			r = io.TeeReader(r, os.Stderr)
+		}
+		return json.NewDecoder(r).Decode(dst)
+	}
+	return nil
+}
 
-	var podList api.PodList
-	if err := json.Unmarshal(body, &podList); err != nil {
-		return nil, fmt.Errorf("failed to decode list of pod resources: %v", err)
+// GetServices returns all services in the client's namespace, regardless of status.
+func (c *Client) GetServices(ctx context.Context) ([]api.Service, error) {
+	var list api.ServiceList
+	if err := c.do(ctx, "GET", c.nsEndpoint()+"services", &list); err != nil {
+		return nil, err
 	}
-	return podList.Items, nil
+	return list.Items, nil
+}
+
+// Endpoint represents a service endpoint address.
+type Endpoint struct {
+	IP       string
+	Port     int
+	PortName string
+	Protocol string // "TCP" or "UDP"; never empty
+}
+
+// GetServiceEndpoints returns the endpoints for the named service.
+// If portName is non-empty, only endpoints matching that port name are returned.
+func (c *Client) GetServiceEndpoints(ctx context.Context, serviceName, portName string) ([]Endpoint, error) {
+	var res api.Endpoints
+	// TODO: path escape serviceName?
+	if err := c.do(ctx, "GET", c.nsEndpoint()+"endpoints/"+serviceName, &res); err != nil {
+		return nil, err
+	}
+	var ep []Endpoint
+	for _, ss := range res.Subsets {
+		for _, port := range ss.Ports {
+			if portName != "" && port.Name != portName {
+				continue
+			}
+			for _, addr := range ss.Addresses {
+				proto := string(port.Protocol)
+				if proto == "" {
+					proto = "TCP"
+				}
+				ep = append(ep, Endpoint{
+					IP:       addr.IP,
+					Port:     port.Port,
+					PortName: port.Name,
+					Protocol: proto,
+				})
+			}
+		}
+	}
+	return ep, nil
+}
+
+// GetPods returns all pods in the client's namespace, regardless of status.
+func (c *Client) GetPods(ctx context.Context) ([]api.Pod, error) {
+	var list api.PodList
+	if err := c.do(ctx, "GET", c.nsEndpoint()+"pods", &list); err != nil {
+		return nil, err
+	}
+	return list.Items, nil
 }
 
 // PodDelete deletes the specified Kubernetes pod.
 func (c *Client) DeletePod(ctx context.Context, podName string) error {
-	url := c.endpointURL + defaultPod + "/" + podName
+	url := c.nsEndpoint() + "pods/" + podName
 	req, err := http.NewRequest("DELETE", url, strings.NewReader(`{"gracePeriodSeconds":0}`))
 	if err != nil {
 		return fmt.Errorf("failed to create request: DELETE %q : %v", url, err)
@@ -254,7 +311,7 @@
 		defer cancel()
 
 		// Make request to Kubernetes API
-		getURL := c.endpointURL + defaultWatchPod + "/" + podName
+		getURL := c.endpointURL + "/watch/namespaces/" + c.namespace + "/pods/" + podName
 		req, err := http.NewRequest("GET", getURL, nil)
 		req.URL.Query().Add("resourceVersion", podResourceVersion)
 		if err != nil {
@@ -317,7 +374,7 @@
 // Retrieve the status of a pod synchronously from the Kube
 // API server.
 func (c *Client) PodStatus(ctx context.Context, podName string) (*api.PodStatus, error) {
-	getURL := c.endpointURL + defaultPod + "/" + podName
+	getURL := c.nsEndpoint() + "pods/" + podName // TODO: escape podName?
 
 	// Make request to Kubernetes API
 	req, err := http.NewRequest("GET", getURL, nil)
@@ -349,7 +406,7 @@
 // in the pod.
 func (c *Client) PodLog(ctx context.Context, podName string) (string, error) {
 	// TODO(evanbrown): support multiple containers
-	url := c.endpointURL + defaultPod + "/" + podName + "/log"
+	url := c.nsEndpoint() + "pods/" + podName + "/log" // TODO: escape podName?
 	req, err := http.NewRequest("GET", url, nil)
 	if err != nil {
 		return "", fmt.Errorf("failed to create request: GET %q : %v", url, err)
@@ -371,26 +428,9 @@
 
 // PodNodes returns the list of nodes that comprise the Kubernetes cluster
 func (c *Client) GetNodes(ctx context.Context) ([]api.Node, error) {
-	url := c.endpointURL + nodes
-	req, err := http.NewRequest("GET", url, nil)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create request: GET %q : %v", url, err)
+	var list api.NodeList
+	if err := c.do(ctx, "GET", c.endpointURL+"/nodes", &list); err != nil {
+		return nil, err
 	}
-	res, err := ctxhttp.Do(ctx, c.httpClient, req)
-	if err != nil {
-		return nil, fmt.Errorf("failed to make request: GET %q: %v", url, err)
-	}
-	body, err := ioutil.ReadAll(res.Body)
-	res.Body.Close()
-	if err != nil {
-		return nil, fmt.Errorf("failed to read response body: GET %q: %v, url, err")
-	}
-	if res.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("http error %d GET %q: %q: %v", res.StatusCode, url, string(body), err)
-	}
-	var nodeList *api.NodeList
-	if err := json.Unmarshal(body, &nodeList); err != nil {
-		return nil, fmt.Errorf("failed to decode node list: %v", err)
-	}
-	return nodeList.Items, nil
+	return list.Items, nil
 }
diff --git a/kubernetes/dialer.go b/kubernetes/dialer.go
index 6439635..9b9aa3e 100644
--- a/kubernetes/dialer.go
+++ b/kubernetes/dialer.go
@@ -7,23 +7,48 @@
 import (
 	"context"
 	"fmt"
+	"math/rand"
 	"net"
 	"strconv"
+	"strings"
+	"time"
 )
 
-// Dialer dials Kubernetes pods.
-//
-// TODO: services also.
-type Dialer struct {
-	kc *Client
+var dialRand = rand.New(rand.NewSource(time.Now().UnixNano()))
+
+// DialService connects to the named service. The service must have only one
+// port. For multi-port services, use DialServicePort.
+func (c *Client) DialService(ctx context.Context, serviceName string) (net.Conn, error) {
+	return c.DialServicePort(ctx, serviceName, "")
 }
 
-func NewDialer(kc *Client) *Dialer {
-	return &Dialer{kc: kc}
+// DialServicePort connects to the named port on the named service.
+// If portName is the empty string, the service must have exactly 1 port.
+func (c *Client) DialServicePort(ctx context.Context, serviceName, portName string) (net.Conn, error) {
+	// TODO: cache the result of GetServiceEndpoints, at least for
+	// a few seconds, to rate-limit calls to the master?
+	eps, err := c.GetServiceEndpoints(ctx, serviceName, portName)
+	if err != nil {
+		return nil, err
+	}
+	if len(eps) == 0 {
+		return nil, fmt.Errorf("no endpoints found for service %q", serviceName)
+	}
+	if portName == "" {
+		firstName := eps[0].PortName
+		for _, p := range eps[1:] {
+			if p.PortName != firstName {
+				return nil, fmt.Errorf("unspecified port name for DialServicePort is ambiguous for service %q (mix of %q, %q, ...)", serviceName, firstName, p.PortName)
+			}
+		}
+	}
+	ep := eps[dialRand.Intn(len(eps))]
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, strings.ToLower(ep.Protocol), net.JoinHostPort(ep.IP, strconv.Itoa(ep.Port)))
 }
 
-func (d *Dialer) Dial(ctx context.Context, podName string, port int) (net.Conn, error) {
-	status, err := d.kc.PodStatus(ctx, podName)
+func (c *Client) DialPod(ctx context.Context, podName string, port int) (net.Conn, error) {
+	status, err := c.PodStatus(ctx, podName)
 	if err != nil {
 		return nil, fmt.Errorf("PodStatus of %q: %v", podName, err)
 	}
diff --git a/kubernetes/gke/gke_test.go b/kubernetes/gke/gke_test.go
index c86efa4..fa7fa5f 100644
--- a/kubernetes/gke/gke_test.go
+++ b/kubernetes/gke/gke_test.go
@@ -20,6 +20,101 @@
 
 // Tests NewClient and also Dialer.
 func TestNewClient(t *testing.T) {
+	ctx := context.Background()
+	foreachCluster(t, func(cl *container.Cluster, kc *kubernetes.Client) {
+		_, err := kc.GetPods(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+func TestDialPod(t *testing.T) {
+	var passed bool
+	var candidates int
+	ctx := context.Background()
+	foreachCluster(t, func(cl *container.Cluster, kc *kubernetes.Client) {
+		if passed {
+			return
+		}
+		pods, err := kc.GetPods(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		for _, pod := range pods {
+			if pod.Status.Phase != "Running" {
+				continue
+			}
+			for _, container := range pod.Spec.Containers {
+				for _, port := range container.Ports {
+					if strings.ToLower(string(port.Protocol)) == "udp" || port.ContainerPort == 0 {
+						continue
+					}
+					candidates++
+					c, err := kc.DialPod(ctx, pod.Name, port.ContainerPort)
+					if err != nil {
+						t.Logf("Dial %q/%q/%d: %v", cl.Name, pod.Name, port.ContainerPort, err)
+						continue
+					}
+					c.Close()
+					t.Logf("Dialed %q/%q/%d.", cl.Name, pod.Name, port.ContainerPort)
+					passed = true
+					return
+				}
+			}
+		}
+	})
+	if candidates == 0 {
+		t.Skip("no pods to dial")
+	}
+	if !passed {
+		t.Errorf("dial failures")
+	}
+}
+
+func TestDialService(t *testing.T) {
+	var passed bool
+	var candidates int
+	ctx := context.Background()
+	foreachCluster(t, func(cl *container.Cluster, kc *kubernetes.Client) {
+		if passed {
+			return
+		}
+		svcs, err := kc.GetServices(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+		for _, svc := range svcs {
+			eps, err := kc.GetServiceEndpoints(ctx, svc.Name, "")
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(eps) != 1 {
+				continue
+			}
+			candidates++
+			conn, err := kc.DialServicePort(ctx, svc.Name, "")
+			if err != nil {
+				t.Logf("%s: DialServicePort(%q) error: %v", cl.Name, svc.Name, err)
+				continue
+			}
+			conn.Close()
+			passed = true
+			t.Logf("Dialed cluster %q service %q.", cl.Name, svc.Name)
+			return
+		}
+
+	})
+	if candidates == 0 {
+		t.Skip("no services to dial")
+	}
+	if !passed {
+		t.Errorf("dial failures")
+	}
+}
+
+func foreachCluster(t *testing.T, fn func(*container.Cluster, *kubernetes.Client)) {
 	if !metadata.OnGCE() {
 		t.Skip("not on GCE; skipping")
 	}
@@ -46,44 +141,12 @@
 	if len(clusters.Clusters) == 0 {
 		t.Skip("no GKE clusters")
 	}
-	var candidates int
 	for _, cl := range clusters.Clusters {
 		kc, err := gke.NewClient(ctx, cl.Name, gke.OptZone(cl.Zone))
 		if err != nil {
 			t.Fatal(err)
 		}
-		defer kc.Close()
-
-		pods, err := kc.GetPods(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		for _, pod := range pods {
-			if pod.Status.Phase != "Running" {
-				continue
-			}
-			for _, container := range pod.Spec.Containers {
-				name := container.Name
-				for _, port := range container.Ports {
-					if strings.ToLower(string(port.Protocol)) == "udp" || port.ContainerPort == 0 {
-						continue
-					}
-					candidates++
-					d := kubernetes.NewDialer(kc)
-					c, err := d.Dial(ctx, name, port.ContainerPort)
-					if err != nil {
-						t.Logf("Dial %q/%q/%d: %v", cl.Name, name, port.ContainerPort, err)
-						continue
-					}
-					c.Close()
-					t.Logf("Dialed %q/%q/%d.", cl.Name, name, port.ContainerPort)
-					return
-				}
-			}
-		}
+		fn(cl, kc)
+		kc.Close()
 	}
-	if candidates == 0 {
-		t.Skip("no pods to dial")
-	}
-	t.Errorf("dial failures")
 }