Merge pull request #411 from iamqizhao/master

Improve RPC cancellation when there is no pending I/O
diff --git a/stream.go b/stream.go
index e72cd3d..2370dd0 100644
--- a/stream.go
+++ b/stream.go
@@ -130,6 +130,12 @@
 	cs.t = t
 	cs.s = s
 	cs.p = &parser{s: s}
+	// Listen on ctx.Done() to detect cancellation when there are no pending
+	// I/O operations on this stream.
+	go func() {
+		<-s.Context().Done()
+		cs.closeTransportStream(transport.ContextErr(s.Context().Err()))
+	}()
 	return cs, nil
 }
 
@@ -143,7 +149,8 @@
 
 	tracing bool // set to EnableTracing when the clientStream is created.
 
-	mu sync.Mutex // protects trInfo.tr
+	mu sync.Mutex
+	closed bool
 	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
 	// and is set to nil when the clientStream's finish method is called.
 	trInfo traceInfo
@@ -157,7 +164,7 @@
 	m, err := cs.s.Header()
 	if err != nil {
 		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.t.CloseStream(cs.s, err)
+			cs.closeTransportStream(err)
 		}
 	}
 	return m, err
@@ -180,7 +187,7 @@
 			return
 		}
 		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.t.CloseStream(cs.s, err)
+			cs.closeTransportStream(err)
 		}
 		err = toRPCErr(err)
 	}()
@@ -212,7 +219,7 @@
 		}
 		// Special handling for client streaming rpc.
 		err = recv(cs.p, cs.codec, m)
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 		if err == nil {
 			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
 		}
@@ -225,7 +232,7 @@
 		return toRPCErr(err)
 	}
 	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 	}
 	if err == io.EOF {
 		if cs.s.StatusCode() == codes.OK {
@@ -243,12 +250,23 @@
 		return
 	}
 	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 	}
 	err = toRPCErr(err)
 	return
 }
 
+// closeTransportStream calls cs.t.CloseStream(cs.s, err) at most once; later calls are no-ops.
+func (cs *clientStream) closeTransportStream(err error) {
+	cs.mu.Lock()
+	alreadyClosed := cs.closed
+	cs.closed = true
+	cs.mu.Unlock()
+	if !alreadyClosed {
+		cs.t.CloseStream(cs.s, err)
+	}
+}
+
 func (cs *clientStream) finish(err error) {
 	if !cs.tracing {
 		return
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 99b0d3c..c5c2f0c 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -798,6 +798,51 @@
 	}
 }
 
+func TestCancelNoIO(t *testing.T) {
+	for _, te := range listTestEnv() {
+		testCancelNoIO(t, te) // one run per environment returned by listTestEnv
+	}
+}
+
+// testCancelNoIO checks that canceling an idle stream frees its slot (server allows 1 live stream per transport).
+func testCancelNoIO(t *testing.T, e env) {
+	s, cc := setUp(t, nil, 1, "", e)
+	tc := testpb.NewTestServiceClient(cc)
+	defer tearDown(s, cc)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel() // releases ctx if we bail out via t.Fatalf; canceling twice is a no-op
+	if _, err := tc.StreamingInputCall(ctx); err != nil {
+		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+	}
+	// Loop until receiving the new max stream setting from the server.
+	for {
+		probeCtx, probeCancel := context.WithTimeout(context.Background(), time.Second)
+		defer probeCancel() // deferred, not called now: a stream opened here must live until its deadline
+		_, err := tc.StreamingInputCall(probeCtx)
+		if grpc.Code(err) == codes.DeadlineExceeded {
+			break
+		}
+		if err != nil {
+			t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
+		}
+		time.Sleep(time.Second)
+	}
+	// Let any RPCs that slipped in before the client got the max streams setting expire.
+	time.Sleep(2 * time.Second)
+	ch := make(chan struct{})
+	go func() {
+		defer close(ch)
+		// This should block until the 1st stream is canceled.
+		waitCtx, waitCancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer waitCancel()
+		if _, err := tc.StreamingInputCall(waitCtx); err != nil {
+			t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+		}
+	}()
+	cancel() // the 1st stream has no pending I/O; canceling it must free its slot
+	<-ch
+}
+
 // The following tests the gRPC streaming RPC implementations.
 // TODO(zhaoq): Have better coverage on error cases.
 var (