http2: implement client-initiated graceful shutdown

Sends a GOAWAY frame and waits for in-flight streams to complete.
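
A minimal usage sketch (hypothetical caller code, not part of this
change; it assumes a *ClientConn obtained elsewhere, e.g. from a
Transport in a test, plus the context, time, and
golang.org/x/net/http2 imports):

	func shutdownConn(cc *http2.ClientConn) error {
		// Give in-flight streams a bounded amount of time to finish.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		// Shutdown sends a GOAWAY frame, then waits for running streams
		// to complete or for ctx to be done, whichever happens first.
		return cc.Shutdown(ctx)
	}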

Fixes golang/go#17292

Change-Id: I2b7dd61446f4ffd9c820fbb21d1233c3b3ad1ba8
Reviewed-on: https://go-review.googlesource.com/30076
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/go17.go b/http2/go17.go
index 47b7fae..bf3a7c1 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -18,6 +18,8 @@
 	context.Context
 }
 
+var errCanceled = context.Canceled
+
 func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
 	ctx, cancel = context.WithCancel(context.Background())
 	ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
@@ -104,3 +106,8 @@
 func (cc *ClientConn) Ping(ctx context.Context) error {
 	return cc.ping(ctx)
 }
+
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+func (cc *ClientConn) Shutdown(ctx context.Context) error {
+	return cc.shutdown(ctx)
+}
diff --git a/http2/not_go17.go b/http2/not_go17.go
index 140434a..976163f 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -8,6 +8,7 @@
 
 import (
 	"crypto/tls"
+	"errors"
 	"net"
 	"net/http"
 	"time"
@@ -18,6 +19,8 @@
 	Err() error
 }
 
+var errCanceled = errors.New("canceled")
+
 type fakeContext struct{}
 
 func (fakeContext) Done() <-chan struct{} { return nil }
@@ -84,4 +87,8 @@
 	return cc.ping(ctx)
 }
 
+func (cc *ClientConn) Shutdown(ctx contextContext) error {
+	return cc.shutdown(ctx)
+}
+
 func (t *Transport) idleConnTimeout() time.Duration { return 0 }
diff --git a/http2/transport.go b/http2/transport.go
index a67112c..d474377 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -159,6 +159,7 @@
 	cond            *sync.Cond // hold mu; broadcast on flow/closed changes
 	flow            flow       // our conn-level flow control quota (cs.flow is per stream)
 	inflow          flow       // peer's conn-level flow control
+	closing         bool       // set when we send a GOAWAY; prevents new streams
 	closed          bool
 	wantSettingsAck bool                     // we sent a SETTINGS frame and haven't heard back
 	goAway          *GoAwayFrame             // if non-nil, the GoAwayFrame we received
@@ -634,7 +635,7 @@
 	if cc.singleUse && cc.nextStreamID > 1 {
 		return false
 	}
-	return cc.goAway == nil && !cc.closed &&
+	return cc.goAway == nil && !cc.closed && !cc.closing &&
 		int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
 }
 
@@ -665,6 +666,88 @@
 	cc.tconn.Close()
 }
 
+var shutdownEnterWaitStateHook = func() {} // overridden in tests to observe Shutdown's wait state
+
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+// The public Shutdown methods are defined in go17.go and not_go17.go.
+func (cc *ClientConn) shutdown(ctx contextContext) error {
+	if err := cc.sendGoAway(); err != nil {
+		return err
+	}
+	// Wait for all in-flight streams to complete or the connection to close.
+	done := make(chan error, 1)
+	cancelled := false // guarded by cc.mu
+	go func() {
+		cc.mu.Lock()
+		defer cc.mu.Unlock()
+		for {
+			if len(cc.streams) == 0 || cc.closed {
+				cc.closed = true
+				done <- cc.tconn.Close()
+				break
+			}
+			if cancelled {
+				break
+			}
+			cc.cond.Wait()
+		}
+	}()
+	shutdownEnterWaitStateHook()
+	select {
+	case err := <-done:
+		return err
+	case <-ctx.Done():
+		cc.mu.Lock()
+		// Free the goroutine above
+		cancelled = true
+		cc.cond.Broadcast()
+		cc.mu.Unlock()
+		return ctx.Err()
+	}
+}
+
+func (cc *ClientConn) sendGoAway() error {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.wmu.Lock()
+	defer cc.wmu.Unlock()
+	if cc.closing {
+		// GOAWAY sent already
+		return nil
+	}
+	// Send a graceful shutdown frame to the server.
+	maxStreamID := cc.nextStreamID
+	if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil {
+		return err
+	}
+	if err := cc.bw.Flush(); err != nil {
+		return err
+	}
+	// Prevent new requests
+	cc.closing = true
+	return nil
+}
+
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *ClientConn) Close() error {
+	cc.mu.Lock()
+	defer cc.cond.Broadcast()
+	defer cc.mu.Unlock()
+	err := errors.New("http2: client connection force closed via ClientConn.Close")
+	for id, cs := range cc.streams {
+		select {
+		case cs.resc <- resAndError{err: err}:
+		default:
+		}
+		cs.bufPipe.CloseWithError(err)
+		delete(cc.streams, id)
+	}
+	cc.closed = true
+	return cc.tconn.Close()
+}
+
 const maxAllocFrameSize = 512 << 10
 
 // frameBuffer returns a scratch buffer suitable for writing DATA frames.
diff --git a/http2/transport_test.go b/http2/transport_test.go
index d9cf115..963a723 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -30,6 +30,7 @@
 	"testing"
 	"time"
 
+	"golang.org/x/net/context"
 	"golang.org/x/net/http2/hpack"
 )
 
@@ -41,12 +42,13 @@
 
 var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
 
-type testContext struct{}
+var canceledCtx context.Context
 
-func (testContext) Done() <-chan struct{}                   { return make(chan struct{}) }
-func (testContext) Err() error                              { panic("should not be called") }
-func (testContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false }
-func (testContext) Value(key interface{}) interface{}       { return nil }
+func init() {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	canceledCtx = ctx
+}
 
 func TestTransportExternal(t *testing.T) {
 	if !*extNet {
@@ -3054,7 +3056,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	if err = cc.Ping(testContext{}); err != nil {
+	if err = cc.Ping(context.Background()); err != nil {
 		t.Fatal(err)
 	}
 }
@@ -3856,3 +3858,194 @@
 	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
 	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
 }
+
+func activeStreams(cc *ClientConn) int {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return len(cc.streams)
+}
+
+type closeMode int
+
+const (
+	closeAtHeaders closeMode = iota
+	closeAtBody
+	shutdown
+	shutdownCancel
+)
+
+// See golang.org/issue/17292
+func testClientConnClose(t *testing.T, closeMode closeMode) {
+	clientDone := make(chan struct{})
+	defer close(clientDone)
+	handlerDone := make(chan struct{})
+	closeDone := make(chan struct{})
+	beforeHeader := func() {}
+	bodyWrite := func(w http.ResponseWriter) {}
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		defer close(handlerDone)
+		beforeHeader()
+		w.WriteHeader(http.StatusOK)
+		w.(http.Flusher).Flush()
+		bodyWrite(w)
+		select {
+		case <-w.(http.CloseNotifier).CloseNotify():
+			// client closed connection before completion
+			if closeMode == shutdown || closeMode == shutdownCancel {
+				t.Error("expected request to complete")
+			}
+		case <-clientDone:
+			if closeMode == closeAtHeaders || closeMode == closeAtBody {
+				t.Error("expected connection closed by client")
+			}
+		}
+	}, optOnlyServer)
+	defer st.Close()
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+	if err != nil {
+		t.Fatal(err)
+	}
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if closeMode == closeAtHeaders {
+		beforeHeader = func() {
+			if err := cc.Close(); err != nil {
+				t.Error(err)
+			}
+			close(closeDone)
+		}
+	}
+	var sendBody chan struct{}
+	if closeMode == closeAtBody {
+		sendBody = make(chan struct{})
+		bodyWrite = func(w http.ResponseWriter) {
+			<-sendBody
+			b := make([]byte, 32)
+			w.Write(b)
+			w.(http.Flusher).Flush()
+			if err := cc.Close(); err != nil {
+				t.Errorf("unexpected ClientConn close error: %v", err)
+			}
+			close(closeDone)
+			w.Write(b)
+			w.(http.Flusher).Flush()
+		}
+	}
+	res, err := cc.RoundTrip(req)
+	if res != nil {
+		defer res.Body.Close()
+	}
+	if closeMode == closeAtHeaders {
+		got := fmt.Sprint(err)
+		want := "http2: client connection force closed via ClientConn.Close"
+		if got != want {
+			t.Fatalf("RoundTrip error = %v, want %v", got, want)
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("RoundTrip: %v", err)
+		}
+		if got, want := activeStreams(cc), 1; got != want {
+			t.Errorf("got %d active streams, want %d", got, want)
+		}
+	}
+	switch closeMode {
+	case shutdownCancel:
+		if err = cc.Shutdown(canceledCtx); err != errCanceled {
+			t.Errorf("got %v, want %v", err, errCanceled)
+		}
+		if !cc.closing {
+			t.Error("expected closing to be true")
+		}
+		if cc.CanTakeNewRequest() {
+			t.Error("expected CanTakeNewRequest to return false")
+		}
+		if v, want := len(cc.streams), 1; v != want {
+			t.Errorf("expected %d active streams, got %d", want, v)
+		}
+		clientDone <- struct{}{}
+		<-handlerDone
+	case shutdown:
+		wait := make(chan struct{})
+		shutdownEnterWaitStateHook = func() {
+			close(wait)
+			shutdownEnterWaitStateHook = func() {}
+		}
+		defer func() { shutdownEnterWaitStateHook = func() {} }()
+		shutdown := make(chan struct{}, 1)
+		go func() {
+			if err = cc.Shutdown(context.Background()); err != nil {
+				t.Error(err)
+			}
+			close(shutdown)
+		}()
+		// Let the shutdown enter its wait state.
+		<-wait
+		cc.mu.Lock()
+		if !cc.closing {
+			t.Error("expected closing to be true")
+		}
+		cc.mu.Unlock()
+		if cc.CanTakeNewRequest() {
+			t.Error("expected CanTakeNewRequest to return false")
+		}
+		if got, want := activeStreams(cc), 1; got != want {
+			t.Errorf("got %d active streams, want %d", got, want)
+		}
+		// Let the active request finish
+		clientDone <- struct{}{}
+		// Wait for the shutdown to end
+		select {
+		case <-shutdown:
+		case <-time.After(2 * time.Second):
+			t.Fatal("expected server connection to close")
+		}
+	case closeAtHeaders, closeAtBody:
+		if closeMode == closeAtBody {
+			go close(sendBody)
+			if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
+				t.Error("expected a Copy error, got nil")
+			}
+		}
+		<-closeDone
+		if got, want := activeStreams(cc), 0; got != want {
+			t.Errorf("got %d active streams, want %d", got, want)
+		}
+		// Wait for the server to get the connection close notice.
+		select {
+		case <-handlerDone:
+		case <-time.After(2 * time.Second):
+			t.Fatal("expected server connection to close")
+		}
+	}
+}
+
+// The client closes the connection just after the server receives the client's HEADERS
+// frame, but before the server sends its HEADERS response. The expected result is a
+// RoundTrip error explaining that the client closed the connection.
+func TestClientConnCloseAtHeaders(t *testing.T) {
+	testClientConnClose(t, closeAtHeaders)
+}
+
+// The client closes the connection between two of the server's response DATA frames.
+// The expected behavior is a response body read error on the client.
+func TestClientConnCloseAtBody(t *testing.T) {
+	testClientConnClose(t, closeAtBody)
+}
+
+// The client sends a GOAWAY frame before the server finishes processing a request.
+// We expect the connection not to close until the request has completed.
+func TestClientConnShutdown(t *testing.T) {
+	testClientConnClose(t, shutdown)
+}
+
+// The client sends a GOAWAY frame before the server finishes processing a request,
+// but cancels the passed context before the request is completed. The expected
+// behavior is the client closing the connection after the context is canceled.
+func TestClientConnShutdownCancel(t *testing.T) {
+	testClientConnClose(t, shutdownCancel)
+}