Merge pull request #423 from iamqizhao/master

Cancel the contexts of all active streams when a server connection is closed
diff --git a/transport/http2_server.go b/transport/http2_server.go
index f3488f8..ceb3055 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -163,22 +163,6 @@
 	if !endHeaders {
 		return s
 	}
-	t.mu.Lock()
-	if t.state != reachable {
-		t.mu.Unlock()
-		return nil
-	}
-	if uint32(len(t.activeStreams)) >= t.maxStreams {
-		t.mu.Unlock()
-		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
-		return nil
-	}
-	s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
-	t.activeStreams[s.id] = s
-	t.mu.Unlock()
-	s.windowHandler = func(n int) {
-		t.updateWindow(s, uint32(n))
-	}
 	if hDec.state.timeoutSet {
 		s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
 	} else {
@@ -202,6 +186,22 @@
 		recv: s.buf,
 	}
 	s.method = hDec.state.method
+	t.mu.Lock()
+	if t.state != reachable {
+		t.mu.Unlock()
+		return nil
+	}
+	if uint32(len(t.activeStreams)) >= t.maxStreams {
+		t.mu.Unlock()
+		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
+		return nil
+	}
+	s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
+	t.activeStreams[s.id] = s
+	t.mu.Unlock()
+	s.windowHandler = func(n int) {
+		t.updateWindow(s, uint32(n))
+	}
 	handle(s)
 	return nil
 }
@@ -660,9 +660,9 @@
 	t.mu.Unlock()
 	close(t.shutdownChan)
 	err = t.conn.Close()
-	// Notify all active streams.
+	// Cancel all active streams.
 	for _, s := range streams {
-		s.write(recvMsg{err: ErrConnClosing})
+		s.cancel()
 	}
 	return
 }
@@ -684,9 +684,8 @@
 	s.state = streamDone
 	s.mu.Unlock()
 	// In case stream sending and receiving are invoked in separate
-	// goroutines (e.g., bi-directional streaming), the caller needs
-	// to call cancel on the stream to interrupt the blocking on
-	// other goroutines.
+	// goroutines (e.g., bi-directional streaming), cancel needs to be
+	// called to interrupt the potential blocking on other goroutines.
 	s.cancel()
 }
 
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 9bf3ed3..06847ce 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -86,11 +86,11 @@
 	}
 	p := make([]byte, len(req))
 	_, err := io.ReadFull(s, p)
-	if err != nil || !bytes.Equal(p, req) {
-		if err == ErrConnClosing {
-			return
-		}
-		t.Fatalf("handleStream got error: %v, want <nil>; result: %v, want %v", err, p, req)
+	if err != nil {
+		return
+	}
+	if !bytes.Equal(p, req) {
+		t.Fatalf("handleStream got %v, want %v", p, req)
 	}
 	// send a response back to the client.
 	h.t.Write(s, resp, &Options{})
@@ -429,6 +429,69 @@
 	server.stop()
 }
 
+// TestServerContextCanceledOnClosedConnection checks that the context of an
+// active server side stream is canceled (context.Canceled) once the client
+// connection is closed.
+func TestServerContextCanceledOnClosedConnection(t *testing.T) {
+	server, ct := setUp(t, 0, math.MaxUint32, suspended)
+	defer server.stop() // release the listener and connections, as sibling tests do
+	callHdr := &CallHdr{
+		Host:   "localhost",
+		Method: "foo",
+	}
+	// Wait until the server transport is set up.
+	var sc *http2Server
+	for sc == nil {
+		server.mu.Lock()
+		for k := range server.conns {
+			var ok bool
+			if sc, ok = k.(*http2Server); !ok {
+				server.mu.Unlock() // unlock first in case the deferred stop() takes this lock
+				t.Fatalf("Failed to convert %v to *http2Server", k)
+			}
+		}
+		server.mu.Unlock()
+		if sc == nil {
+			time.Sleep(time.Millisecond)
+		}
+	}
+	cc, ok := ct.(*http2Client)
+	if !ok {
+		t.Fatalf("Failed to convert %v to *http2Client", ct)
+	}
+	s, err := ct.NewStream(context.Background(), callHdr)
+	if err != nil {
+		t.Fatalf("Failed to open stream: %v", err)
+	}
+	// Make sure the headers frame is flushed out.
+	<-cc.writableChan
+	if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
+		t.Fatalf("Failed to write data: %v", err)
+	}
+	cc.writableChan <- 0
+	// Poll until the server side stream is created, giving up after 5 seconds.
+	var ss *Stream
+	for deadline := time.Now().Add(5 * time.Second); ss == nil; {
+		if time.Now().After(deadline) {
+			t.Fatal("Timed out waiting for the server side stream.")
+		}
+		time.Sleep(time.Millisecond)
+		sc.mu.Lock()
+		ss = sc.activeStreams[s.id]
+		sc.mu.Unlock()
+	}
+	// Closing the client connection must cancel the server side context.
+	cc.Close()
+	select {
+	case <-ss.Context().Done():
+		if err := ss.Context().Err(); err != context.Canceled {
+			t.Fatalf("ss.Context().Err() got %v, want %v", err, context.Canceled)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("Failed to cancel the context of the server side stream.")
+	}
+}
+
 func TestServerWithMisbehavedClient(t *testing.T) {
 	server, ct := setUp(t, 0, math.MaxUint32, suspended)
 	callHdr := &CallHdr{