Merge pull request #423 from iamqizhao/master
Cancel the contexts of all active streams when a server connection is closed
diff --git a/transport/http2_server.go b/transport/http2_server.go
index f3488f8..ceb3055 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -163,22 +163,6 @@
if !endHeaders {
return s
}
- t.mu.Lock()
- if t.state != reachable {
- t.mu.Unlock()
- return nil
- }
- if uint32(len(t.activeStreams)) >= t.maxStreams {
- t.mu.Unlock()
- t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
- return nil
- }
- s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
- t.activeStreams[s.id] = s
- t.mu.Unlock()
- s.windowHandler = func(n int) {
- t.updateWindow(s, uint32(n))
- }
if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else {
@@ -202,6 +186,22 @@
recv: s.buf,
}
s.method = hDec.state.method
+ t.mu.Lock()
+ if t.state != reachable {
+ t.mu.Unlock()
+ return nil
+ }
+ if uint32(len(t.activeStreams)) >= t.maxStreams {
+ t.mu.Unlock()
+ t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
+ return nil
+ }
+ s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
+ t.activeStreams[s.id] = s
+ t.mu.Unlock()
+ s.windowHandler = func(n int) {
+ t.updateWindow(s, uint32(n))
+ }
handle(s)
return nil
}
@@ -660,9 +660,9 @@
t.mu.Unlock()
close(t.shutdownChan)
err = t.conn.Close()
- // Notify all active streams.
+ // Cancel all active streams.
for _, s := range streams {
- s.write(recvMsg{err: ErrConnClosing})
+ s.cancel()
}
return
}
@@ -684,9 +684,8 @@
s.state = streamDone
s.mu.Unlock()
// In case stream sending and receiving are invoked in separate
- // goroutines (e.g., bi-directional streaming), the caller needs
- // to call cancel on the stream to interrupt the blocking on
- // other goroutines.
+ // goroutines (e.g., bi-directional streaming), cancel needs to be
+ // called to interrupt the potential blocking on other goroutines.
s.cancel()
}
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 9bf3ed3..06847ce 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -86,11 +86,11 @@
}
p := make([]byte, len(req))
_, err := io.ReadFull(s, p)
- if err != nil || !bytes.Equal(p, req) {
- if err == ErrConnClosing {
- return
- }
- t.Fatalf("handleStream got error: %v, want <nil>; result: %v, want %v", err, p, req)
+ if err != nil {
+ return
+ }
+ if !bytes.Equal(p, req) {
+ t.Fatalf("handleStream got %v, want %v", p, req)
}
// send a response back to the client.
h.t.Write(s, resp, &Options{})
@@ -429,6 +429,69 @@
server.stop()
}
+// TestServerContextCanceledOnClosedConnection checks that closing the client
+// connection cancels the contexts of the active server-side streams on it.
+func TestServerContextCanceledOnClosedConnection(t *testing.T) {
+	server, ct := setUp(t, 0, math.MaxUint32, suspended)
+	// Stop the server even if the test fails part-way.
+	defer server.stop()
+	callHdr := &CallHdr{
+		Host:   "localhost",
+		Method: "foo",
+	}
+	// waitFor polls cond every millisecond and fails the test if cond is still
+	// false after 5 seconds, so a regression fails fast instead of hanging.
+	waitFor := func(what string, cond func() bool) {
+		deadline := time.Now().Add(5 * time.Second)
+		for !cond() {
+			if time.Now().After(deadline) {
+				t.Fatalf("Timed out waiting for %s", what)
+			}
+			time.Sleep(time.Millisecond)
+		}
+	}
+	// Wait until the server transport is set up.
+	var sc *http2Server
+	waitFor("the server transport to be set up", func() bool {
+		server.mu.Lock()
+		// Deferred so the lock is released even when Fatalf exits the goroutine.
+		defer server.mu.Unlock()
+		for k := range server.conns {
+			var ok bool
+			if sc, ok = k.(*http2Server); !ok {
+				t.Fatalf("Failed to convert %v to *http2Server", k)
+			}
+		}
+		return sc != nil
+	})
+	cc, ok := ct.(*http2Client)
+	if !ok {
+		t.Fatalf("Failed to convert %v to *http2Client", ct)
+	}
+	s, err := ct.NewStream(context.Background(), callHdr)
+	if err != nil {
+		t.Fatalf("Failed to open stream: %v", err)
+	}
+	// Make sure the headers frame is flushed out.
+	<-cc.writableChan
+	if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
+		t.Fatalf("Failed to write data: %v", err)
+	}
+	cc.writableChan <- 0
+	// Wait until the server side stream for s.id itself is registered; waiting
+	// for any entry could leave ss nil and panic below.
+	var ss *Stream
+	waitFor("the server side stream to be created", func() bool {
+		sc.mu.Lock()
+		defer sc.mu.Unlock()
+		ss = sc.activeStreams[s.id]
+		return ss != nil
+	})
+	// Closing the client connection must cancel the server stream's context.
+	cc.Close()
+	select {
+	case <-ss.Context().Done():
+		if ss.Context().Err() != context.Canceled {
+			t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Failed to cancel the context of the server side stream.")
+	}
+}
+
func TestServerWithMisbehavedClient(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{