http2: implement client initiated graceful shutdown
Sends a GOAWAY frame and wait for the in-flight streams to complete.
Fixes golang/go#17292
Change-Id: I2b7dd61446f4ffd9c820fbb21d1233c3b3ad1ba8
Reviewed-on: https://go-review.googlesource.com/30076
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/go17.go b/http2/go17.go
index 47b7fae..bf3a7c1 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -18,6 +18,8 @@
context.Context
}
+var errCanceled = context.Canceled
+
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
@@ -104,3 +106,8 @@
func (cc *ClientConn) Ping(ctx context.Context) error {
return cc.ping(ctx)
}
+
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+func (cc *ClientConn) Shutdown(ctx context.Context) error {
+ return cc.shutdown(ctx)
+}
diff --git a/http2/not_go17.go b/http2/not_go17.go
index 140434a..976163f 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -8,6 +8,7 @@
import (
"crypto/tls"
+ "errors"
"net"
"net/http"
"time"
@@ -18,6 +19,8 @@
Err() error
}
+var errCanceled = errors.New("canceled")
+
type fakeContext struct{}
func (fakeContext) Done() <-chan struct{} { return nil }
@@ -84,4 +87,8 @@
return cc.ping(ctx)
}
+func (cc *ClientConn) Shutdown(ctx contextContext) error {
+ return cc.shutdown(ctx)
+}
+
func (t *Transport) idleConnTimeout() time.Duration { return 0 }
diff --git a/http2/transport.go b/http2/transport.go
index a67112c..d474377 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -159,6 +159,7 @@
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow flow // our conn-level flow control quota (cs.flow is per stream)
inflow flow // peer's conn-level flow control
+ closing bool
closed bool
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
@@ -634,7 +635,7 @@
if cc.singleUse && cc.nextStreamID > 1 {
return false
}
- return cc.goAway == nil && !cc.closed &&
+ return cc.goAway == nil && !cc.closed && !cc.closing &&
int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
}
@@ -665,6 +666,88 @@
cc.tconn.Close()
}
+var shutdownEnterWaitStateHook = func() {}
+
+// Shutdown gracefully close the client connection, waiting for running streams to complete.
+// Public implementation is in go17.go and not_go17.go
+func (cc *ClientConn) shutdown(ctx contextContext) error {
+ if err := cc.sendGoAway(); err != nil {
+ return err
+ }
+ // Wait for all in-flight streams to complete or connection to close
+ done := make(chan error, 1)
+ cancelled := false // guarded by cc.mu
+ go func() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if len(cc.streams) == 0 || cc.closed {
+ cc.closed = true
+ done <- cc.tconn.Close()
+ break
+ }
+ if cancelled {
+ break
+ }
+ cc.cond.Wait()
+ }
+ }()
+ shutdownEnterWaitStateHook()
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ cc.mu.Lock()
+ // Free the goroutine above
+ cancelled = true
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ return ctx.Err()
+ }
+}
+
+func (cc *ClientConn) sendGoAway() error {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if cc.closing {
+ // GOAWAY sent already
+ return nil
+ }
+ // Send a graceful shutdown frame to server
+ maxStreamID := cc.nextStreamID
+ if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil {
+ return err
+ }
+ if err := cc.bw.Flush(); err != nil {
+ return err
+ }
+ // Prevent new requests
+ cc.closing = true
+ return nil
+}
+
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *ClientConn) Close() error {
+ cc.mu.Lock()
+ defer cc.cond.Broadcast()
+ defer cc.mu.Unlock()
+ err := errors.New("http2: client connection force closed via ClientConn.Close")
+ for id, cs := range cc.streams {
+ select {
+ case cs.resc <- resAndError{err: err}:
+ default:
+ }
+ cs.bufPipe.CloseWithError(err)
+ delete(cc.streams, id)
+ }
+ cc.closed = true
+ return cc.tconn.Close()
+}
+
const maxAllocFrameSize = 512 << 10
// frameBuffer returns a scratch buffer suitable for writing DATA frames.
diff --git a/http2/transport_test.go b/http2/transport_test.go
index d9cf115..963a723 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -30,6 +30,7 @@
"testing"
"time"
+ "golang.org/x/net/context"
"golang.org/x/net/http2/hpack"
)
@@ -41,12 +42,13 @@
var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
-type testContext struct{}
+var canceledCtx context.Context
-func (testContext) Done() <-chan struct{} { return make(chan struct{}) }
-func (testContext) Err() error { panic("should not be called") }
-func (testContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false }
-func (testContext) Value(key interface{}) interface{} { return nil }
+func init() {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ canceledCtx = ctx
+}
func TestTransportExternal(t *testing.T) {
if !*extNet {
@@ -3054,7 +3056,7 @@
if err != nil {
t.Fatal(err)
}
- if err = cc.Ping(testContext{}); err != nil {
+ if err = cc.Ping(context.Background()); err != nil {
t.Fatal(err)
}
}
@@ -3856,3 +3858,191 @@
b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
}
+
+func activeStreams(cc *ClientConn) int {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return len(cc.streams)
+}
+
+type closeMode int
+
+const (
+ closeAtHeaders closeMode = iota
+ closeAtBody
+ shutdown
+ shutdownCancel
+)
+
+// See golang.org/issue/17292
+func testClientConnClose(t *testing.T, closeMode closeMode) {
+ clientDone := make(chan struct{})
+ defer close(clientDone)
+ handlerDone := make(chan struct{})
+ closeDone := make(chan struct{})
+ beforeHeader := func() {}
+ bodyWrite := func(w http.ResponseWriter) {}
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ defer close(handlerDone)
+ beforeHeader()
+ w.WriteHeader(http.StatusOK)
+ w.(http.Flusher).Flush()
+ bodyWrite(w)
+ select {
+ case <-w.(http.CloseNotifier).CloseNotify():
+ // client closed connection before completion
+ if closeMode == shutdown || closeMode == shutdownCancel {
+ t.Error("expected request to complete")
+ }
+ case <-clientDone:
+ if closeMode == closeAtHeaders || closeMode == closeAtBody {
+ t.Error("expected connection closed by client")
+ }
+ }
+ }, optOnlyServer)
+ defer st.Close()
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if closeMode == closeAtHeaders {
+ beforeHeader = func() {
+ if err := cc.Close(); err != nil {
+ t.Error(err)
+ }
+ close(closeDone)
+ }
+ }
+ var sendBody chan struct{}
+ if closeMode == closeAtBody {
+ sendBody = make(chan struct{})
+ bodyWrite = func(w http.ResponseWriter) {
+ <-sendBody
+ b := make([]byte, 32)
+ w.Write(b)
+ w.(http.Flusher).Flush()
+ if err := cc.Close(); err != nil {
+ t.Errorf("unexpected ClientConn close error: %v", err)
+ }
+ close(closeDone)
+ w.Write(b)
+ w.(http.Flusher).Flush()
+ }
+ }
+ res, err := cc.RoundTrip(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if closeMode == closeAtHeaders {
+ got := fmt.Sprint(err)
+ want := "http2: client connection force closed via ClientConn.Close"
+ if got != want {
+ t.Fatalf("RoundTrip error = %v, want %v", got, want)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("RoundTrip: %v", err)
+ }
+ if got, want := activeStreams(cc), 1; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ }
+ switch closeMode {
+ case shutdownCancel:
+ if err = cc.Shutdown(canceledCtx); err != errCanceled {
+ t.Errorf("got %v, want %v", err, errCanceled)
+ }
+ if cc.closing == false {
+ t.Error("expected closing to be true")
+ }
+ if cc.CanTakeNewRequest() == true {
+ t.Error("CanTakeNewRequest to return false")
+ }
+ if v, want := len(cc.streams), 1; v != want {
+ t.Errorf("expected %d active streams, got %d", want, v)
+ }
+ clientDone <- struct{}{}
+ <-handlerDone
+ case shutdown:
+ wait := make(chan struct{})
+ shutdownEnterWaitStateHook = func() {
+ close(wait)
+ shutdownEnterWaitStateHook = func() {}
+ }
+ defer func() { shutdownEnterWaitStateHook = func() {} }()
+ shutdown := make(chan struct{}, 1)
+ go func() {
+ if err = cc.Shutdown(context.Background()); err != nil {
+ t.Error(err)
+ }
+ close(shutdown)
+ }()
+ // Let the shutdown to enter wait state
+ <-wait
+ cc.mu.Lock()
+ if cc.closing == false {
+ t.Error("expected closing to be true")
+ }
+ cc.mu.Unlock()
+ if cc.CanTakeNewRequest() == true {
+ t.Error("CanTakeNewRequest to return false")
+ }
+ if got, want := activeStreams(cc), 1; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ // Let the active request finish
+ clientDone <- struct{}{}
+ // Wait for the shutdown to end
+ select {
+ case <-shutdown:
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected server connection to close")
+ }
+ case closeAtHeaders, closeAtBody:
+ if closeMode == closeAtBody {
+ go close(sendBody)
+ if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
+ t.Error("expected a Copy error, got nil")
+ }
+ }
+ <-closeDone
+ if got, want := activeStreams(cc), 0; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ // wait for server to get the connection close notice
+ select {
+ case <-handlerDone:
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected server connection to close")
+ }
+ }
+}
+
+// The client closes the connection just after the server got the client's HEADERS
+// frame, but before the server sends its HEADERS response back. The expected
+// result is an error on RoundTrip explaining the client closed the connection.
+func TestClientConnCloseAtHeaders(t *testing.T) {
+ testClientConnClose(t, closeAtHeaders)
+}
+
+// The client closes the connection between two server's response DATA frames.
+// The expected behavior is a response body io read error on the client.
+func TestClientConnCloseAtBody(t *testing.T) {
+ testClientConnClose(t, closeAtBody)
+}
+
+// The client sends a GOAWAY frame before the server finished processing a request.
+// We expect the connection not to close until the request is completed.
+func TestClientConnShutdown(t *testing.T) {
+ testClientConnClose(t, shutdown)
+}
+
+// The client sends a GOAWAY frame before the server finishes processing a request,
+// but cancels the passed context before the request is completed. The expected
+// behavior is the client closing the connection after the context is canceled.
+func TestClientConnShutdownCancel(t *testing.T) {
+ testClientConnClose(t, shutdownCancel)
+}