sandbox, internal/gcpdial: dial sandbox backends directly, add health check

Don't use GCP internal load balancers or health checkers at the moment:
for internal GCP reasons that we may resolve later, the frontend now
discovers the sandbox backend VMs itself, health-checks them over HTTP,
and dials a healthy instance directly.
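
In sketch form (the full version is in sandbox.go below), the frontend's
HTTP transport now swaps the internal load balancer hostname for the IP
of a healthy instance chosen by the new gcpdial package:

	rigd := gcpdial.NewRegionInstanceGroupDialer("golang-org", "us-central1", "play-sandbox-rigm")
	tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		if addr == "sandbox.play-sandbox-fwd.il4.us-central1.lb.golang-org.internal:80" {
			ip, err := rigd.PickIP(ctx) // a backend currently passing /healthz
			if err != nil {
				return nil, err
			}
			addr = net.JoinHostPort(ip, "80")
		}
		var d net.Dialer
		return d.DialContext(ctx, netw, addr)
	}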

Updates golang/go#25224

Change-Id: If443c25a8a40582f00a1bb51923214ab3b891b28
Reviewed-on: https://go-review.googlesource.com/c/playground/+/216397
Reviewed-by: Alexander Rakoczy <alex@golang.org>
diff --git a/go.mod b/go.mod
index 039b6ea..6fbe381 100644
--- a/go.mod
+++ b/go.mod
@@ -8,5 +8,6 @@
 	golang.org/x/build v0.0.0-20190709001953-30c0e6b89ea0
 	golang.org/x/mod v0.1.1-0.20191119225628-919e395dadcd
 	golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e
+	google.golang.org/api v0.4.0
 	grpc.go4.org v0.0.0-20170609214715-11d0a25b4919
 )
diff --git a/internal/gcpdial/gcpdial.go b/internal/gcpdial/gcpdial.go
new file mode 100644
index 0000000..212e149
--- /dev/null
+++ b/internal/gcpdial/gcpdial.go
@@ -0,0 +1,303 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package gcpdial monitors VM instance groups to let frontends dial
+// them directly without going through an internal load balancer.
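+//
+// A minimal sketch of the intended use (sandbox.go in this change is the
+// real consumer): create one Dialer per instance group, then ask it for a
+// healthy backend's IP whenever a connection is needed:
+//
+//	d := gcpdial.NewRegionInstanceGroupDialer("golang-org", "us-central1", "play-sandbox-rigm")
+//	ip, err := d.PickIP(ctx) // blocks until some instance passes /healthz
+//	if err != nil {
+//		// handle the error
+//	}
+//	addr := net.JoinHostPort(ip, "80")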
+package gcpdial
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"math/rand"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	compute "google.golang.org/api/compute/v1"
+)
+
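+// A Dialer tracks the instances in a GCP instance group, health-checks them,
+// and hands out the IP of a healthy one via PickIP.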
+type Dialer struct {
+	lister instanceLister
+
+	mu            sync.Mutex
+	lastInstances []string           // URLs of instances
+	prober        map[string]*prober // URL of instance to its prober
+	ready         map[string]string  // URL of instance to ready IP
+}
+
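+// A prober health-checks a single instance and records whether it's ready in
+// its Dialer's ready map.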
+type prober struct {
+	d       *Dialer
+	instURL string
+	cancel  func()          // called by Dialer to shut down this prober
+	ctx     context.Context // context that's canceled from above
+
+	pi *parsedInstance
+
+	// owned by the probeLoop goroutine:
+	ip      string
+	healthy bool
+}
+
+func newProber(d *Dialer, instURL string) *prober {
+	ctx, cancel := context.WithCancel(context.Background())
+	return &prober{
+		d:       d,
+		instURL: instURL,
+		cancel:  cancel,
+		ctx:     ctx,
+	}
+}
+
+func (p *prober) probeLoop() {
+	log.Printf("start prober for %s", p.instURL)
+	defer log.Printf("end prober for %s", p.instURL)
+
+	pi, err := parseInstance(p.instURL)
+	if err != nil {
+		log.Printf("gcpdial: prober %s: failed to parse: %v", p.instURL, err)
+		return
+	}
+	p.pi = pi
+
+	t := time.NewTicker(15 * time.Second)
+	defer t.Stop()
+	for {
+		p.probe()
+		select {
+		case <-p.ctx.Done():
+			return
+		case <-t.C:
+		}
+	}
+}
+
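+// probe does one HTTP GET of http://<instance IP>/healthz and, if the health
+// state changed, updates the Dialer's ready map.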
+func (p *prober) probe() {
+	if p.ip == "" && !p.getIP() {
+		return
+	}
+	ctx, cancel := context.WithTimeout(p.ctx, 30*time.Second)
+	defer cancel()
+	req, err := http.NewRequest("GET", "http://"+p.ip+"/healthz", nil)
+	if err != nil {
+		log.Printf("gcpdial: prober %s: NewRequest: %v", p.instURL, err)
+		return
+	}
+	req = req.WithContext(ctx)
+	res, err := http.DefaultClient.Do(req)
+	if res != nil {
+		defer res.Body.Close()
+		defer io.Copy(ioutil.Discard, res.Body)
+	}
+	healthy := err == nil && res.StatusCode == http.StatusOK
+	if healthy == p.healthy {
+		// No change.
+		return
+	}
+	p.healthy = healthy
+
+	p.d.mu.Lock()
+	defer p.d.mu.Unlock()
+	if healthy {
+		if p.d.ready == nil {
+			p.d.ready = map[string]string{}
+		}
+		p.d.ready[p.instURL] = p.ip
+		// TODO: possible optimization: wake up Dialer.PickIP
+		// waiters when an instance becomes ready rather than
+		// having them poll once a second.
+	} else {
+		delete(p.d.ready, p.instURL)
+		var why string
+		if err != nil {
+			why = err.Error()
+		} else {
+			why = res.Status
+		}
+		log.Printf("gcpdial: prober %s: no longer healthy; %v", p.instURL, why)
+	}
+}
+
+// getIP ensures p.ip is populated with the instance's private (10.x.x.x) IP
+// and reports whether it is.
+func (p *prober) getIP() bool {
+	if p.ip != "" {
+		return true
+	}
+	ctx, cancel := context.WithTimeout(p.ctx, 30*time.Second)
+	defer cancel()
+	svc, err := compute.NewService(ctx)
+	if err != nil {
+		log.Printf("gcpdial: prober %s: NewService: %v", p.instURL, err)
+		return false
+	}
+	inst, err := svc.Instances.Get(p.pi.Project, p.pi.Zone, p.pi.Name).Context(ctx).Do()
+	if err != nil {
+		log.Printf("gcpdial: prober %s: Get: %v", p.instURL, err)
+		return false
+	}
+	var ip string
+	var other []string
+	for _, ni := range inst.NetworkInterfaces {
+		if strings.HasPrefix(ni.NetworkIP, "10.") {
+			ip = ni.NetworkIP
+		} else {
+			other = append(other, ni.NetworkIP)
+		}
+	}
+	if ip == "" {
+		log.Printf("gcpdial: prober %s: didn't find 10.x.x.x IP; found %q", p.instURL, other)
+		return false
+	}
+	p.ip = ip
+	return true
+}
+
+// PickIP returns the IP of a randomly chosen healthy instance, waiting until
+// one is available or ctx expires.
+func (d *Dialer) PickIP(ctx context.Context) (ip string, err error) {
+	for {
+		if ip, ok := d.pickIP(); ok {
+			return ip, nil
+		}
+		select {
+		case <-ctx.Done():
+			return "", ctx.Err()
+		case <-time.After(time.Second):
+		}
+	}
+}
+
+func (d *Dialer) pickIP() (string, bool) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if len(d.ready) == 0 {
+		return "", false
+	}
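+	// Pick a random entry: choose an index in [0, len(d.ready)) and walk the
+	// map until that many values have been skipped.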
+	num := rand.Intn(len(d.ready))
+	for _, v := range d.ready {
+		if num > 0 {
+			num--
+			continue
+		}
+		return v, true
+	}
+	panic("not reachable")
+}
+
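+// poll repeatedly lists the instance group and reconciles the set of running
+// probers with the instances it returns, backing off 10 seconds after a
+// listing error.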
+func (d *Dialer) poll() {
+	for {
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		res, err := d.lister.ListInstances(ctx)
+		cancel()
+		if err != nil {
+			log.Printf("gcpdial: polling %v: %v", d.lister, err)
+			time.Sleep(10 * time.Second)
+			continue
+		}
+
+		want := map[string]bool{} // the res []string turned into a set
+		for _, instURL := range res {
+			want[instURL] = true
+		}
+
+		d.mu.Lock()
+		// Stop + remove any health check probers that no longer appear in the
+		// instance group.
+		for instURL, prober := range d.prober {
+			if !want[instURL] {
+				prober.cancel()
+				delete(d.prober, instURL)
+			}
+		}
+		// And start any new health check probers that are newly added
+		// (or newly known at least) to the instance group.
+		for _, instURL := range res {
+			if _, ok := d.prober[instURL]; ok {
+				continue
+			}
+			p := newProber(d, instURL)
+			go p.probeLoop()
+			if d.prober == nil {
+				d.prober = map[string]*prober{}
+			}
+			d.prober[instURL] = p
+		}
+		d.lastInstances = res
+		d.mu.Unlock()
+	}
+}
+
+// NewRegionInstanceGroupDialer returns a new Dialer that dials the named
+// regional instance group in the provided project and region.
+//
+// It begins polling immediately, and there's no way to stop it
+// (until we need one, at least).
+func NewRegionInstanceGroupDialer(project, region, group string) *Dialer {
+	d := &Dialer{
+		lister: regionInstanceGroupLister{project, region, group},
+	}
+	go d.poll()
+	return d
+}
+
+// instanceLister is something that can list the current set of VMs.
+//
+// The idea is that we'll have both zonal and regional instance group listers,
+// but currently we only have regionInstanceGroupLister below.
+type instanceLister interface {
+	// ListInstances returns a list of instances in their API URL form.
+	//
+	// The API URL form is parseable by the parseInstance func. See its docs.
+	ListInstances(context.Context) ([]string, error)
+}
+
+// regionInstanceGroupLister is an instanceLister implementation that lists
+// the VMs currently in a regional instance group.
+type regionInstanceGroupLister struct {
+	project, region, group string
+}
+
+func (rig regionInstanceGroupLister) ListInstances(ctx context.Context) (ret []string, err error) {
+	svc, err := compute.NewService(ctx)
+	if err != nil {
+		return nil, err
+	}
+	rigSvc := svc.RegionInstanceGroups
+	insts, err := rigSvc.ListInstances(rig.project, rig.region, rig.group, &compute.RegionInstanceGroupsListInstancesRequest{
+		InstanceState: "RUNNING",
+		PortName:      "", // all
+	}).Context(ctx).MaxResults(500).Do()
+	if err != nil {
+		return nil, err
+	}
+	// TODO: pagination for really large sets? Currently we truncate the results
+	// to the first 500 VMs, which seems like plenty for now.
+	// 500 is the maximum the API supports; see:
+	// https://pkg.go.dev/google.golang.org/api/compute/v1?tab=doc#RegionInstanceGroupsListInstancesCall.MaxResults
+	for _, it := range insts.Items {
+		ret = append(ret, it.Instance)
+	}
+	return ret, nil
+}
+
+// parsedInstance contains the project, zone, and name of a VM.
+type parsedInstance struct {
+	Project, Zone, Name string
+}
+
+// parseInstance parses e.g. "https://www.googleapis.com/compute/v1/projects/golang-org/zones/us-central1-c/instances/playsandbox-7sj8" into its parts.
+func parseInstance(u string) (*parsedInstance, error) {
+	const pfx = "https://www.googleapis.com/compute/v1/projects/"
+	if !strings.HasPrefix(u, pfx) {
+		return nil, fmt.Errorf("failed to parse instance %q; doesn't begin with %q", u, pfx)
+	}
+	u = u[len(pfx):] // "golang-org/zones/us-central1-c/instances/playsandbox-7sj8"
+	f := strings.Split(u, "/")
+	if len(f) != 5 || f[1] != "zones" || f[3] != "instances" {
+		return nil, fmt.Errorf("failed to parse instance %q; unexpected format", u)
+	}
+	return &parsedInstance{f[0], f[2], f[4]}, nil
+}
diff --git a/internal/gcpdial/gcpdialtool/gcpdialtool.go b/internal/gcpdial/gcpdialtool/gcpdialtool.go
new file mode 100644
index 0000000..1bcbd6c
--- /dev/null
+++ b/internal/gcpdial/gcpdialtool/gcpdialtool.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The gcpdialtool command is an interactive validation tool for the
+// gcpdial package.
+package main
+
+import (
+	"context"
+	"flag"
+	"log"
+	"os"
+	"time"
+
+	"golang.org/x/playground/internal/gcpdial"
+)
+
+var (
+	proj   = flag.String("project", "golang-org", "GCP project name")
+	region = flag.String("region", "us-central1", "GCP region")
+	group  = flag.String("group", "play-sandbox-rigm", "regional instance group name")
+)
+
+func main() {
+	flag.Parse()
+	log.SetOutput(os.Stdout)
+	log.SetFlags(log.Flags() | log.Lmicroseconds)
+
+	log.Printf("starting")
+	d := gcpdial.NewRegionInstanceGroupDialer(*proj, *region, *group)
+
+	ctx := context.Background()
+	for {
+		ip, err := d.PickIP(ctx)
+		if err != nil {
+			log.Fatal(err)
+		}
+		log.Printf("picked %v", ip)
+		time.Sleep(time.Second)
+	}
+}
diff --git a/sandbox.go b/sandbox.go
index e55e814..58431d4 100644
--- a/sandbox.go
+++ b/sandbox.go
@@ -20,6 +20,7 @@
 	"go/token"
 	"io"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"os"
 	"os/exec"
@@ -27,11 +28,13 @@
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"text/template"
 	"time"
 
 	"cloud.google.com/go/compute/metadata"
 	"github.com/bradfitz/gomemcache/memcache"
+	"golang.org/x/playground/internal/gcpdial"
 	"golang.org/x/playground/sandbox/sandboxtypes"
 )
 
@@ -460,7 +463,7 @@
 			req.Header.Add("X-Argument", testParam)
 		}
 		req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewReader(exeBytes)), nil }
-		res, err := http.DefaultClient.Do(req)
+		res, err := sandboxBackendClient().Do(req)
 		if err != nil {
 			return nil, err
 		}
@@ -586,6 +589,47 @@
 	panic(fmt.Sprintf("no SANDBOX_BACKEND_URL environment and no default defined for project %q", id))
 }
 
+var sandboxBackendOnce struct {
+	sync.Once
+	c *http.Client
+}
+
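+// sandboxBackendClient returns the HTTP client (lazily initialized) used to
+// contact the sandbox execution backend.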
+func sandboxBackendClient() *http.Client {
+	sandboxBackendOnce.Do(initSandboxBackendClient)
+	return sandboxBackendOnce.c
+}
+
+// initSandboxBackendClient runs from a sync.Once and initializes
+// sandboxBackendOnce.c with the *http.Client we'll use to contact the
+// sandbox execution backend.
+func initSandboxBackendClient() {
+	id, _ := metadata.ProjectID()
+	switch id {
+	case "golang-org":
+		// For production, use a funky Transport dialer that
+		// contacts the backend directly, without going through an
+		// internal load balancer, due to internal GCP
+		// reasons, which we might resolve later. This might
+		// be a temporary hack.
+		tr := http.DefaultTransport.(*http.Transport).Clone()
+		rigd := gcpdial.NewRegionInstanceGroupDialer("golang-org", "us-central1", "play-sandbox-rigm")
+		tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+			if addr == "sandbox.play-sandbox-fwd.il4.us-central1.lb.golang-org.internal:80" {
+				ip, err := rigd.PickIP(ctx)
+				if err != nil {
+					return nil, err
+				}
+				addr = net.JoinHostPort(ip, "80") // and fallthrough
+			}
+			var d net.Dialer
+			return d.DialContext(ctx, netw, addr)
+		}
+		sandboxBackendOnce.c = &http.Client{Transport: tr}
+	default:
+		sandboxBackendOnce.c = http.DefaultClient
+	}
+}
+
 const healthProg = `
 package main
 
diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go
index d43f3e9..41f1f9b 100644
--- a/sandbox/sandbox.go
+++ b/sandbox/sandbox.go
@@ -96,6 +96,8 @@
 	c.waitVal = c.cmd.Wait()
 }
 
+// httpServer is the server handling the sandbox backend's HTTP requests.
+// It's kept in a package variable so getHealthCached can close idle
+// connections (via SetKeepAlivesEnabled) when the backend becomes unhealthy.
+var httpServer *http.Server
+
 func main() {
 	flag.Parse()
 	if *mode == "contained" {
@@ -132,7 +134,8 @@
 
 	go makeWorkers()
 
-	log.Fatal(http.ListenAndServe(*listenAddr, nil))
+	httpServer = &http.Server{Addr: *listenAddr}
+	log.Fatal(httpServer.ListenAndServe())
 }
 
 func handleSignals() {
@@ -142,10 +145,68 @@
 	log.Fatalf("closing on signal %d: %v", s, s)
 }
 
+var healthStatus struct {
+	sync.Mutex
+	lastCheck time.Time
+	lastVal   error
+}
+
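+// getHealthCached returns the result of checkHealth, recomputing it at most
+// once every five seconds.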
+func getHealthCached() error {
+	healthStatus.Lock()
+	defer healthStatus.Unlock()
+	const recentEnough = 5 * time.Second
+	if healthStatus.lastCheck.After(time.Now().Add(-recentEnough)) {
+		return healthStatus.lastVal
+	}
+
+	err := checkHealth()
+	if healthStatus.lastVal == nil && err != nil {
+		// On transition from healthy to unhealthy, close all
+		// idle HTTP connections so clients with them open
+		// don't reuse them. TODO: remove this if/when we
+		// switch away from direct load balancing between
+		// frontends and this sandbox backend.
+		httpServer.SetKeepAlivesEnabled(false) // side effect of closing all idle ones
+		httpServer.SetKeepAlivesEnabled(true)  // and restore it back to normal
+	}
+	healthStatus.lastVal = err
+	healthStatus.lastCheck = time.Now()
+	return err
+}
+
+// checkHealth does a health check, without any caching. It's called via
+// getHealthCached.
+func checkHealth() error {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+	c, err := getContainer(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to get a sandbox container: %v", err)
+	}
+	// TODO: execute something too? for now we just check that sandboxed containers
+	// are available.
+	closed := make(chan struct{})
+	go func() {
+		c.Close()
+		close(closed)
+	}()
+	select {
+	case <-closed:
+		// success.
+		return nil
+	case <-ctx.Done():
+		return fmt.Errorf("timeout closing sandbox container")
+	}
+}
+
 func healthHandler(w http.ResponseWriter, r *http.Request) {
+	// TODO: split into liveness & readiness checks?
+	if err := getHealthCached(); err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		fmt.Fprintf(w, "health check failure: %v\n", err)
+		return
+	}
 	io.WriteString(w, "OK\n")
-	// TODO: more? split into liveness & readiness checks? check
-	// number of active/stuck containers, memory?
 }
 
 func rootHandler(w http.ResponseWriter, r *http.Request) {
diff --git a/sandbox/sandbox.tf b/sandbox/sandbox.tf
index e6c84b3..14de7cb 100644
--- a/sandbox/sandbox.tf
+++ b/sandbox/sandbox.tf
@@ -1,8 +1,6 @@
 # TODO: move the network configuration into terraform too? It was created by hand with:
 # gcloud compute networks subnets update golang --region=us-central1 --enable-private-ip-google-access
 #
-# Likewise, the firewall rules for health checking were created imperatively based on
-# https://cloud.google.com/load-balancing/docs/health-checks#firewall_rules
 
 terraform {
   backend "gcs" {
@@ -47,9 +45,6 @@
   network_interface {
     network = "golang"
   }
-  # Allow both "non-legacy" and "legacy" health checks, so we can change types in the future.
-  # See https://cloud.google.com/load-balancing/docs/health-checks
-  tags = ["allow-health-checks", "allow-network-lb-health-checks"]
   service_account {
     scopes = ["logging-write", "storage-ro"]
   }
@@ -76,7 +71,7 @@
 
   autoscaling_policy {
     max_replicas    = 10
-    min_replicas    = 2
+    min_replicas    = 3
     cooldown_period = 60
 
     cpu_utilization {
@@ -101,49 +96,9 @@
     name = "http"
     port = 80
   }
-
-  auto_healing_policies {
-    health_check      = "${google_compute_health_check.default.self_link}"
-    initial_delay_sec = 30
-  }
 }
 
 data "google_compute_region_instance_group" "rig" {
   provider  = "google-beta"
   self_link = "${google_compute_region_instance_group_manager.rigm.instance_group}"
 }
-
-resource "google_compute_health_check" "default" {
-  name                = "play-sandbox-rigm-health-check"
-  check_interval_sec  = 5
-  timeout_sec         = 5
-  healthy_threshold   = 2
-  unhealthy_threshold = 10 # 50 seconds
-  http_health_check {
-    request_path = "/healthz"
-    port         = 80
-  }
-}
-
-resource "google_compute_region_backend_service" "default" {
-  name          = "play-sandbox-backend-service"
-  region        = "us-central1"
-  health_checks = ["${google_compute_health_check.default.self_link}"]
-  backend {
-    group = "${data.google_compute_region_instance_group.rig.self_link}"
-  }
-}
-
-resource "google_compute_forwarding_rule" "default" {
-  name                  = "play-sandbox-fwd"
-  region                = "us-central1"
-  network               = "golang"
-  ports                 = ["80"]
-  load_balancing_scheme = "INTERNAL"
-  ip_protocol           = "TCP"
-  backend_service       = "${google_compute_region_backend_service.default.self_link}"
-
-  # Adding a service label gives us a DNS name:
-  # sandbox.play-sandbox-fwd.il4.us-central1.lb.golang-org.internal
-  service_label = "sandbox"
-}