Merge pull request #552 from bradfitz/concurrency
Fix flakiness of TestCancelNoIO with http.Handler-based server transport
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 0bb7289..e19333b 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -34,6 +34,7 @@
package grpc_test
import (
+ "bytes"
"flag"
"fmt"
"io"
@@ -374,9 +375,71 @@
return envs
}
+// serverSetUp is the old way to start a test server. New callers should use newTest.
+// TODO(bradfitz): update all tests to newTest and delete this.
func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
- t.Logf("Running test in %s environment...", e.name)
- sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
+ te := &test{
+ t: t,
+ e: e,
+ healthServer: hs,
+ maxStream: maxStream,
+ cp: cp,
+ dc: dc,
+ }
+ if servON {
+ te.testServer = &testServer{security: e.security}
+ }
+ te.startServer()
+ return te.srv, te.srvAddr
+}
+
+// test is an end-to-end test. It should be created with the newTest
+// func, modified as needed, and then started with its startServer method.
+// It should be cleaned up with the tearDown method.
+type test struct {
+ t *testing.T
+ e env
+
+ // Configurable knobs, after newTest returns:
+ testServer testpb.TestServiceServer // nil means none
+ healthServer *health.HealthServer // nil means disabled
+ maxStream uint32
+ cp grpc.Compressor // nil means no server compression
+ dc grpc.Decompressor // nil means no server decompression
+ userAgent string
+
+ // srv and srvAddr are set once startServer is called.
+ srv *grpc.Server
+ srvAddr string
+
+ cc *grpc.ClientConn // nil until requested via clientConn
+}
+
+func (te *test) tearDown() {
+ te.srv.Stop()
+ if te.cc != nil {
+ te.cc.Close()
+ }
+}
+
+// newTest returns a new test using the provided testing.T and
+// environment. It is returned with default values. Tests should
+// modify it before calling its startServer and clientConn methods.
+func newTest(t *testing.T, e env) *test {
+ return &test{
+ t: t,
+ e: e,
+ testServer: &testServer{security: e.security},
+ maxStream: math.MaxUint32,
+ }
+}
+
+// startServer starts a gRPC server listening. Callers should defer a
+// call to te.tearDown to clean up.
+func (te *test) startServer() {
+ e := te.e
+ te.t.Logf("Running test in %s environment...", e.name)
+ sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream), grpc.RPCCompressor(te.cp), grpc.RPCDecompressor(te.dc)}
la := ":0"
switch e.network {
case "unix":
@@ -385,37 +448,46 @@
}
lis, err := net.Listen(e.network, la)
if err != nil {
- t.Fatalf("Failed to listen: %v", err)
+ te.t.Fatalf("Failed to listen: %v", err)
}
if e.security == "tls" {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
- t.Fatalf("Failed to generate credentials %v", err)
+ te.t.Fatalf("Failed to generate credentials %v", err)
}
sopts = append(sopts, grpc.Creds(creds))
}
- s = grpc.NewServer(sopts...)
+ s := grpc.NewServer(sopts...)
+ te.srv = s
if e.httpHandler {
s.TestingUseHandlerImpl()
}
- if hs != nil {
- healthpb.RegisterHealthServer(s, hs)
+ if te.healthServer != nil {
+ healthpb.RegisterHealthServer(s, te.healthServer)
}
- if servON {
- testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
+ if te.testServer != nil {
+ testpb.RegisterTestServiceServer(s, te.testServer)
}
- go s.Serve(lis)
- addr = la
+ addr := la
switch e.network {
case "unix":
default:
_, port, err := net.SplitHostPort(lis.Addr().String())
if err != nil {
- t.Fatalf("Failed to parse listener address: %v", err)
+ te.t.Fatalf("Failed to parse listener address: %v", err)
}
addr = "localhost:" + port
}
- return
+
+ go s.Serve(lis)
+ te.srvAddr = addr
+}
+
+func (te *test) clientConn() *grpc.ClientConn {
+ if te.cc == nil {
+ te.cc = clientSetUp(te.t, te.srvAddr, te.cp, te.dc, te.userAgent, te.e)
+ }
+ return te.cc
}
func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) {
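A rough sketch of the intended usage of the new fixture (customServer here is a stand-in for any TestServiceServer override; the real pattern appears in testServerStreaming_Concurrent below):

	func testSomething(t *testing.T, e env) {
		te := newTest(t, e)
		te.testServer = customServer{} // optional: replace the default testServer
		te.startServer()
		defer te.tearDown()

		tc := testpb.NewTestServiceClient(te.clientConn())
		// ... issue RPCs against tc ...
	}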
@@ -888,17 +960,28 @@
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
- ctx, cancel := context.WithCancel(context.Background())
+
+ // Start one blocked RPC for which we'll never send streaming
+ // input. This consumes the server's single allowed concurrent
+ // stream (MaxConcurrentStreams is 1), so future RPCs will hang.
+ ctx, cancelFirst := context.WithCancel(context.Background())
_, err := tc.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
- // Loop until receiving the new max stream setting from the server.
+
+ // Loop until the ClientConn receives the initial settings
+ // frame from the server announcing the maximum number of
+ // concurrent streams. We know it has arrived once an RPC
+ // fails with codes.DeadlineExceeded instead of succeeding.
+ // TODO(bradfitz): add internal test hook for this (Issue 534)
for {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancelSecond := context.WithTimeout(context.Background(), 250*time.Millisecond)
_, err := tc.StreamingInputCall(ctx)
+ cancelSecond()
if err == nil {
- time.Sleep(time.Second)
+ time.Sleep(50 * time.Millisecond)
continue
}
if grpc.Code(err) == codes.DeadlineExceeded {
@@ -906,19 +989,23 @@
}
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
}
- // If there are any RPCs slipping before the client receives the max streams setting,
- // let them be expired.
- time.Sleep(2 * time.Second)
+ // If there are any RPCs in flight before the client receives
+ // the max streams setting, let them be expired.
+ // TODO(bradfitz): add internal test hook for this (Issue 534)
+ time.Sleep(500 * time.Millisecond)
+
ch := make(chan struct{})
go func() {
defer close(ch)
+
// This should be blocked until the 1st is canceled.
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancelThird := context.WithTimeout(context.Background(), 2*time.Second)
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
+ cancelThird()
}()
- cancel()
+ cancelFirst()
<-ch
}
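The new cancelSecond and cancelThird calls follow the standard context idiom: context.WithTimeout returns a cancel func that should always be invoked so the context's timer is released promptly instead of lingering until the deadline fires. Outside a loop the usual form is:

	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel() // release the timer as soon as the surrounding function returns
	_, err := tc.StreamingInputCall(ctx)

Inside the polling loop above, cancelSecond is called explicitly rather than deferred, since a defer would pin every iteration's context until the test function returns.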
@@ -1169,6 +1256,87 @@
}
}
+// concurrentSendServer is a TestServiceServer whose
+// StreamingOutputCall makes ten serial Send calls, sending payloads
+// "0".."9", inclusive. TestServerStreaming_Concurrent verifies they
+// were received in the correct order, and that there were no races.
+//
+// All other TestServiceServer methods crash if called.
+type concurrentSendServer struct {
+ testpb.TestServiceServer
+}
+
+func (s concurrentSendServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
+ for i := 0; i < 10; i++ {
+ stream.Send(&testpb.StreamingOutputCallResponse{
+ Payload: &testpb.Payload{
+ Body: []byte{'0' + uint8(i)},
+ },
+ })
+ }
+ return nil
+}
+
+// Tests doing a bunch of concurrent streaming output calls.
+func TestServerStreaming_Concurrent(t *testing.T) {
+ defer leakCheck(t)()
+ for _, e := range listTestEnv() {
+ testServerStreaming_Concurrent(t, e)
+ }
+}
+
+func testServerStreaming_Concurrent(t *testing.T, e env) {
+ et := newTest(t, e)
+ et.testServer = concurrentSendServer{}
+ et.startServer()
+ defer et.tearDown()
+
+ cc := et.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+
+ doStreamingCall := func() {
+ req := &testpb.StreamingOutputCallRequest{}
+ stream, err := tc.StreamingOutputCall(context.Background(), req)
+ if err != nil {
+ t.Errorf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
+ return
+ }
+ var ngot int
+ var buf bytes.Buffer
+ for {
+ reply, err := stream.Recv()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ // Don't t.Fatal from a non-test goroutine; report and stop this stream.
+ t.Error(err)
+ return
+ }
+ ngot++
+ if buf.Len() > 0 {
+ buf.WriteByte(',')
+ }
+ buf.Write(reply.GetPayload().GetBody())
+ }
+ if want := 10; ngot != want {
+ t.Errorf("Got %d replies, want %d", ngot, want)
+ }
+ if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
+ t.Errorf("Got replies %q; want %q", got, want)
+ }
+ }
+
+ var wg sync.WaitGroup
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doStreamingCall()
+ }()
+ }
+ wg.Wait()
+}
+
func TestClientStreaming(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
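(The ordering check in doStreamingCall catches interleaving bugs directly; the "no races" half of the comment presumably assumes the suite is run under the race detector, e.g. go test -race.)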
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 5d1bffe..8bfbf97 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -75,10 +75,10 @@
}
st := &serverHandlerTransport{
- rw: w,
- req: r,
- closedCh: make(chan struct{}),
- wroteStatus: make(chan struct{}),
+ rw: w,
+ req: r,
+ closedCh: make(chan struct{}),
+ writes: make(chan func()),
}
if v := r.Header.Get("grpc-timeout"); v != "" {
@@ -132,7 +132,10 @@
closeOnce sync.Once
closedCh chan struct{} // closed on Close
- wroteStatus chan struct{} // closed on WriteStatus
+ // writes is a channel of code to run serialized in the
+ // ServeHTTP (HandleStreams) goroutine. The channel is closed
+ // when WriteStatus is called.
+ writes chan func()
}
func (ht *serverHandlerTransport) Close() error {
@@ -166,31 +169,43 @@
func (a strAddr) String() string { return string(a) }
-func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
- ht.writeCommonHeaders(s)
-
- // And flush, in case no header or body has been sent yet.
- // This forces a separation of headers and trailers if this is the
- // first call (for example, in end2end tests's TestNoService).
- ht.rw.(http.Flusher).Flush()
-
- h := ht.rw.Header()
- h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
- if statusDesc != "" {
- h.Set("Grpc-Message", statusDesc)
+// do runs fn in the ServeHTTP goroutine.
+func (ht *serverHandlerTransport) do(fn func()) error {
+ select {
+ case ht.writes <- fn:
+ return nil
+ case <-ht.closedCh:
+ return ErrConnClosing
}
- if md := s.Trailer(); len(md) > 0 {
- for k, vv := range md {
- for _, v := range vv {
- // http2 ResponseWriter mechanism to
- // send undeclared Trailers after the
- // headers have possibly been written.
- h.Add(http2.TrailerPrefix+k, v)
+}
+
+func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
+ err := ht.do(func() {
+ ht.writeCommonHeaders(s)
+
+ // And flush, in case no header or body has been sent yet.
+ // This forces a separation of headers and trailers if this is the
+ // first call (for example, in end2end tests' TestNoService).
+ ht.rw.(http.Flusher).Flush()
+
+ h := ht.rw.Header()
+ h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
+ if statusDesc != "" {
+ h.Set("Grpc-Message", statusDesc)
+ }
+ if md := s.Trailer(); len(md) > 0 {
+ for k, vv := range md {
+ for _, v := range vv {
+ // http2 ResponseWriter mechanism to
+ // send undeclared Trailers after the
+ // headers have possibly been written.
+ h.Add(http2.TrailerPrefix+k, v)
+ }
}
}
- }
- close(ht.wroteStatus)
- return nil
+ })
+ close(ht.writes)
+ return err
}
// writeCommonHeaders sets common headers on the first write
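Note the ordering in WriteStatus above: do hands the status-writing closure to the unbuffered writes channel, so when do succeeds, the ServeHTTP goroutine has already accepted the closure and will execute it; the subsequent close(ht.writes) merely tells runStream to exit, and the final status cannot be dropped.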
@@ -219,28 +234,30 @@
}
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
- ht.writeCommonHeaders(s)
- ht.rw.Write(data)
- if !opts.Delay {
- ht.rw.(http.Flusher).Flush()
- }
- return nil
+ return ht.do(func() {
+ ht.writeCommonHeaders(s)
+ ht.rw.Write(data)
+ if !opts.Delay {
+ ht.rw.(http.Flusher).Flush()
+ }
+ })
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
- ht.writeCommonHeaders(s)
- h := ht.rw.Header()
- for k, vv := range md {
- for _, v := range vv {
- h.Add(k, v)
+ return ht.do(func() {
+ ht.writeCommonHeaders(s)
+ h := ht.rw.Header()
+ for k, vv := range md {
+ for _, v := range vv {
+ h.Add(k, v)
+ }
}
- }
- ht.rw.WriteHeader(200)
- ht.rw.(http.Flusher).Flush()
- return nil
+ ht.rw.WriteHeader(200)
+ ht.rw.(http.Flusher).Flush()
+ })
}
-func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
+func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var ctx context.Context
@@ -251,12 +268,18 @@
ctx, cancel = context.WithCancel(context.Background())
}
+ // requestOver is closed when either the request's context is done
+ // or the status has been written via WriteStatus.
+ requestOver := make(chan struct{})
+
// clientGone receives a single value if peer is gone, either
// because the underlying connection is dead or because the
// peer sends an http2 RST_STREAM.
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
go func() {
select {
+ case <-requestOver:
+ return
case <-ht.closedCh:
case <-clientGone:
}
@@ -285,10 +308,6 @@
s.ctx = newContextWithStream(ctx, s)
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
- // requestOver is closed when either the request's context is done
- // or the status has been written via WriteStatus.
- requestOver := make(chan struct{})
-
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
go func() {
@@ -296,34 +315,40 @@
for {
buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
n, err := req.Body.Read(buf)
- select {
- case <-requestOver:
- return
- default:
- }
if n > 0 {
s.buf.put(&recvMsg{data: buf[:n]})
}
if err != nil {
s.buf.put(&recvMsg{err: err})
- break
+ return
}
}
}()
- // runStream is provided by the *grpc.Server.serveStreams.
- // It starts a goroutine handling s and exits immediately.
- runStream(s)
+ // startStream is provided by the *grpc.Server's serveStreams.
+ // It starts a goroutine serving s and exits immediately.
+ // The goroutine that is started is the one that then calls
+ // into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
+ startStream(s)
- // Wait for the stream to be done. It is considered done when
- // either its context is done, or we've written its status.
- select {
- case <-ctx.Done():
- case <-ht.wroteStatus:
- }
+ ht.runStream()
close(requestOver)
// Wait for reading goroutine to finish.
req.Body.Close()
<-readerDone
}
+
+func (ht *serverHandlerTransport) runStream() {
+ for {
+ select {
+ case fn, ok := <-ht.writes:
+ if !ok {
+ return
+ }
+ fn()
+ case <-ht.closedCh:
+ return
+ }
+ }
+}
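The do/runStream pair is an instance of a general funnel pattern: many goroutines submit closures, and a single owner goroutine executes them serially, which is what keeps all http.ResponseWriter access on the ServeHTTP goroutine. A self-contained sketch of the same shape (hypothetical names, not part of this change):

	package serialize

	import "errors"

	var errClosed = errors.New("serializer: closed")

	type serializer struct {
		work chan func()   // closures to run on the owner goroutine
		done chan struct{} // closed to abort, like ht.closedCh
	}

	// do submits fn to the owner goroutine, failing if the
	// serializer has been shut down.
	func (s *serializer) do(fn func()) error {
		select {
		case s.work <- fn:
			return nil
		case <-s.done:
			return errClosed
		}
	}

	// run executes submitted closures until work is closed or done
	// fires. It must be called from exactly one goroutine.
	func (s *serializer) run() {
		for {
			select {
			case fn, ok := <-s.work:
				if !ok {
					return
				}
				fn()
			case <-s.done:
				return
			}
		}
	}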
diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go
index faa95af..1fee72f 100644
--- a/transport/handler_server_test.go
+++ b/transport/handler_server_test.go
@@ -293,13 +293,14 @@
func TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
- st.ht.HandleStreams(func(s *Stream) {
+ handleStream := func(s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
st.ht.WriteStatus(s, codes.OK, "")
- })
+ }
+ st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -323,9 +324,10 @@
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
- st.ht.HandleStreams(func(s *Stream) {
+ handleStream := func(s *Stream) {
st.ht.WriteStatus(s, statusCode, msg)
- })
+ }
+ st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -358,7 +360,7 @@
if err != nil {
t.Fatal(err)
}
- ht.HandleStreams(func(s *Stream) {
+ runStream := func(s *Stream) {
defer bodyw.Close()
select {
case <-s.ctx.Done():
@@ -372,7 +374,8 @@
return
}
ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
- })
+ }
+ ht.HandleStreams(func(s *Stream) { go runStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -381,6 +384,6 @@
"Grpc-Message": {"too slow"},
}
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
- t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader)
+ t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
}
}
diff --git a/transport/http2_server.go b/transport/http2_server.go
index cce2e12..b9d4959 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -62,8 +62,8 @@
maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection
// writableChan synchronizes write access to the transport.
- // A writer acquires the write lock by sending a value on writableChan
- // and releases it by receiving from writableChan.
+ // A writer acquires the write lock by receiving a value on writableChan
+ // and releases it by sending on writableChan.
writableChan chan int
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
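The corrected comment describes a channel used as a mutex. Assuming writableChan is created with capacity 1 and seeded with a single value (an assumption about the surrounding code, not shown in this hunk), the discipline is:

	wc := make(chan int, 1)
	wc <- 0 // seed: the write lock starts out free

	<-wc    // acquire the write lock by receiving the value
	// ... perform the serialized write to the framer ...
	wc <- 0 // release the lock by sending the value back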
diff --git a/transport/transport.go b/transport/transport.go
index 6c3b943..3b934b4 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -352,30 +352,40 @@
// Options provides additional hints and information for message
// transmission.
type Options struct {
- // Indicate whether it is the last piece for this stream.
+ // Last indicates whether this write is the last piece for
+ // this stream.
Last bool
- // The hint to transport impl whether the data could be buffered for
- // batching write. Transport impl can feel free to ignore it.
+
+ // Delay is a hint to the transport implementation about
+ // whether the data could be buffered for a batching write.
+ // The transport implementation may ignore the hint.
Delay bool
}
// CallHdr carries the information of a particular RPC.
type CallHdr struct {
- // Host specifies peer host.
+ // Host specifies the peer's host.
Host string
+
// Method specifies the operation to perform.
Method string
- // RecvCompress specifies the compression algorithm applied on inbound messages.
+
+ // RecvCompress specifies the compression algorithm applied on
+ // inbound messages.
RecvCompress string
- // SendCompress specifies the compression algorithm applied on outbound message.
+
+ // SendCompress specifies the compression algorithm applied on
+ // outbound messages.
SendCompress string
- // Flush indicates if new stream command should be sent to the peer without
- // waiting for the first data. This is a hint though. The transport may modify
- // the flush decision for performance purpose.
+
+ // Flush indicates whether a new stream command should be sent
+ // to the peer without waiting for the first data. This is
+ // only a hint. The transport may modify the flush decision
+ // for performance purposes.
Flush bool
}
-// ClientTransport is the common interface for all gRPC client side transport
+// ClientTransport is the common interface for all gRPC client-side transport
// implementations.
type ClientTransport interface {
// Close tears down this transport. Once it returns, the transport
@@ -404,21 +414,33 @@
Error() <-chan struct{}
}
-// ServerTransport is the common interface for all gRPC server side transport
+// ServerTransport is the common interface for all gRPC server-side transport
// implementations.
+//
+// Methods may be called concurrently from multiple goroutines, but
+// Write methods for a given Stream will be called serially.
type ServerTransport interface {
- // WriteStatus sends the status of a stream to the client.
- WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
- // Write sends the data for the given stream.
- Write(s *Stream, data []byte, opts *Options) error
- // WriteHeader sends the header metedata for the given stream.
- WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream))
+
+ // WriteHeader sends the header metadata for the given stream.
+ // WriteHeader is not necessarily called on every stream.
+ WriteHeader(s *Stream, md metadata.MD) error
+
+ // Write sends the data for the given stream.
+ // Write is not necessarily called on every stream.
+ Write(s *Stream, data []byte, opts *Options) error
+
+ // WriteStatus sends the status of a stream to the client.
+ // WriteStatus is the final call made on a stream and always
+ // occurs.
+ WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
+
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
Close() error
+
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
}
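For illustration, a hypothetical caller honoring the per-stream contract documented above (not code from this change; import paths assumed to match this repository):

	import (
		"google.golang.org/grpc/codes"
		"google.golang.org/grpc/metadata"
		"google.golang.org/grpc/transport"
	)

	func serveOneStream(st transport.ServerTransport, s *transport.Stream) {
		// WriteHeader and Write are each optional for a given stream...
		st.WriteHeader(s, metadata.MD{"k": []string{"v"}})
		st.Write(s, []byte("payload"), &transport.Options{Last: true})

		// ...but WriteStatus always occurs, and is always the final call.
		st.WriteStatus(s, codes.OK, "")
	}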