Merge pull request #411 from iamqizhao/master
Improve rpc cancellation when there is no pending I/O
diff --git a/stream.go b/stream.go
index e72cd3d..2370dd0 100644
--- a/stream.go
+++ b/stream.go
@@ -130,6 +130,12 @@
cs.t = t
cs.s = s
cs.p = &parser{s: s}
+ // Listen on ctx.Done() to detect cancellation when there are no pending
+ // I/O operations on this stream.
+ go func() {
+ <-s.Context().Done()
+ cs.closeTransportStream(transport.ContextErr(s.Context().Err()))
+ }()
return cs, nil
}
@@ -143,7 +149,8 @@
tracing bool // set to EnableTracing when the clientStream is created.
- mu sync.Mutex // protects trInfo.tr
+ mu sync.Mutex
+ closed bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
trInfo traceInfo
@@ -157,7 +164,7 @@
m, err := cs.s.Header()
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
- cs.t.CloseStream(cs.s, err)
+ cs.closeTransportStream(err)
}
}
return m, err
@@ -180,7 +187,7 @@
return
}
if _, ok := err.(transport.ConnectionError); !ok {
- cs.t.CloseStream(cs.s, err)
+ cs.closeTransportStream(err)
}
err = toRPCErr(err)
}()
@@ -212,7 +219,7 @@
}
// Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, m)
- cs.t.CloseStream(cs.s, err)
+ cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
@@ -225,7 +232,7 @@
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); !ok {
- cs.t.CloseStream(cs.s, err)
+ cs.closeTransportStream(err)
}
if err == io.EOF {
if cs.s.StatusCode() == codes.OK {
@@ -243,12 +250,23 @@
return
}
if _, ok := err.(transport.ConnectionError); !ok {
- cs.t.CloseStream(cs.s, err)
+ cs.closeTransportStream(err)
}
err = toRPCErr(err)
return
}
+func (cs *clientStream) closeTransportStream(err error) {
+ cs.mu.Lock()
+ if cs.closed {
+ cs.mu.Unlock()
+ return
+ }
+ cs.closed = true
+ cs.mu.Unlock()
+ cs.t.CloseStream(cs.s, err)
+}
+
func (cs *clientStream) finish(err error) {
if !cs.tracing {
return
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 99b0d3c..c5c2f0c 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -798,6 +798,51 @@
}
}
+func TestCancelNoIO(t *testing.T) {
+ for _, e := range listTestEnv() {
+ testCancelNoIO(t, e)
+ }
+}
+
+func testCancelNoIO(t *testing.T, e env) {
+ // Only allows 1 live stream per server transport.
+ s, cc := setUp(t, nil, 1, "", e)
+ tc := testpb.NewTestServiceClient(cc)
+ defer tearDown(s, cc)
+ ctx, cancel := context.WithCancel(context.Background())
+ _, err := tc.StreamingInputCall(ctx)
+ if err != nil {
+ t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+ }
+ // Loop until receiving the new max stream setting from the server.
+ for {
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ _, err := tc.StreamingInputCall(ctx)
+ if err == nil {
+ time.Sleep(time.Second)
+ continue
+ }
+ if grpc.Code(err) == codes.DeadlineExceeded {
+ break
+ }
+ t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
+ }
+ // If any RPCs slipped through before the client received the max streams setting,
+ // let them expire.
+ time.Sleep(2 * time.Second)
+ ch := make(chan struct{})
+ go func() {
+ defer close(ch)
+ // This should be blocked until the 1st is canceled.
+ ctx, _ := context.WithTimeout(context.Background(), 2 * time.Second)
+ if _, err := tc.StreamingInputCall(ctx); err != nil {
+ t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+ }
+ }()
+ cancel()
+ <-ch
+}
+
// The following tests the gRPC streaming RPC implementations.
// TODO(zhaoq): Have better coverage on error cases.
var (