Merge pull request #389 from iamqizhao/master
remove the sync.WaitGroup parameter from ServerTransport.HandleStreams
diff --git a/server.go b/server.go
index 5c20dd4..487a75c 100644
--- a/server.go
+++ b/server.go
@@ -259,7 +259,8 @@
s.mu.Unlock()
go func() {
- st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) {
+ var wg sync.WaitGroup
+ st.HandleStreams(func(stream *transport.Stream) {
var trInfo *traceInfo
if EnableTracing {
trInfo = &traceInfo{
@@ -278,6 +279,7 @@
wg.Done()
}()
})
+ wg.Wait()
s.mu.Lock()
delete(s.conns, st)
s.mu.Unlock()
diff --git a/transport/http2_server.go b/transport/http2_server.go
index ed8fde0..52d0ee5 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -138,7 +138,7 @@
// operateHeader takes action on the decoded headers. It returns the current
// stream if there are remaining headers on the wire (in the following
// Continuation frame).
-func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) {
+func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) {
defer func() {
if pendingStream == nil {
hDec.state = decodeState{}
@@ -202,13 +202,13 @@
recv: s.buf,
}
s.method = hDec.state.method
- handle(s, wg)
+ handle(s)
return nil
}
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
-func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
+func (t *http2Server) HandleStreams(handle func(*Stream)) {
// Check the validity of client preface.
preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil {
@@ -238,8 +238,6 @@
hDec := newHPACKDecoder()
var curStream *Stream
- var wg sync.WaitGroup
- defer wg.Wait()
for {
frame, err := t.framer.readFrame()
if err != nil {
@@ -268,9 +266,9 @@
fc: fc,
}
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
- curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
+ curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
case *http2.ContinuationFrame:
- curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg)
+ curStream = t.operateHeaders(hDec, curStream, frame, false, handle)
case *http2.DataFrame:
t.handleData(frame)
case *http2.RSTStreamFrame:
diff --git a/transport/transport.go b/transport/transport.go
index c319a5f..e1e7f57 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -391,7 +391,7 @@
// WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
- HandleStreams(func(*Stream, *sync.WaitGroup))
+ HandleStreams(func(*Stream))
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index ba1d66a..9bf3ed3 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -77,8 +77,7 @@
misbehaved
)
-func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) {
- defer wg.Done()
+func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
@@ -100,16 +99,13 @@
}
// handleStreamSuspension blocks until s.ctx is canceled.
-func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) {
- wg.Add(1)
+func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
go func() {
<-s.ctx.Done()
- wg.Done()
}()
}
-func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) {
- defer wg.Done()
+func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
conn, ok := s.ServerTransport().(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
@@ -173,14 +169,12 @@
case suspended:
go transport.HandleStreams(h.handleStreamSuspension)
case misbehaved:
- go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
- wg.Add(1)
- go h.handleStreamMisbehave(t, s, wg)
+ go transport.HandleStreams(func(s *Stream) {
+ go h.handleStreamMisbehave(t, s)
})
default:
- go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
- wg.Add(1)
- go h.handleStream(t, s, wg)
+ go transport.HandleStreams(func(s *Stream) {
+ go h.handleStream(t, s)
})
}
}