Merge pull request #389 from iamqizhao/master

remove sync.WaitGroup param from ServerTransport.HandleStreams; the server now owns the WaitGroup and waits on it after HandleStreams returns
diff --git a/server.go b/server.go
index 5c20dd4..487a75c 100644
--- a/server.go
+++ b/server.go
@@ -259,7 +259,8 @@
 		s.mu.Unlock()
 
 		go func() {
-			st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) {
+			var wg sync.WaitGroup
+			st.HandleStreams(func(stream *transport.Stream) {
 				var trInfo *traceInfo
 				if EnableTracing {
 					trInfo = &traceInfo{
@@ -278,6 +279,7 @@
 					wg.Done()
 				}()
 			})
+			wg.Wait()
 			s.mu.Lock()
 			delete(s.conns, st)
 			s.mu.Unlock()
diff --git a/transport/http2_server.go b/transport/http2_server.go
index ed8fde0..52d0ee5 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -138,7 +138,7 @@
 // operateHeader takes action on the decoded headers. It returns the current
 // stream if there are remaining headers on the wire (in the following
 // Continuation frame).
-func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) {
+func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) {
 	defer func() {
 		if pendingStream == nil {
 			hDec.state = decodeState{}
@@ -202,13 +202,13 @@
 		recv: s.buf,
 	}
 	s.method = hDec.state.method
-	handle(s, wg)
+	handle(s)
 	return nil
 }
 
 // HandleStreams receives incoming streams using the given handler. This is
 // typically run in a separate goroutine.
-func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
+func (t *http2Server) HandleStreams(handle func(*Stream)) {
 	// Check the validity of client preface.
 	preface := make([]byte, len(clientPreface))
 	if _, err := io.ReadFull(t.conn, preface); err != nil {
@@ -238,8 +238,6 @@
 
 	hDec := newHPACKDecoder()
 	var curStream *Stream
-	var wg sync.WaitGroup
-	defer wg.Wait()
 	for {
 		frame, err := t.framer.readFrame()
 		if err != nil {
@@ -268,9 +266,9 @@
 				fc:  fc,
 			}
 			endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
-			curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
+			curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
 		case *http2.ContinuationFrame:
-			curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg)
+			curStream = t.operateHeaders(hDec, curStream, frame, false, handle)
 		case *http2.DataFrame:
 			t.handleData(frame)
 		case *http2.RSTStreamFrame:
diff --git a/transport/transport.go b/transport/transport.go
index c319a5f..e1e7f57 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -391,7 +391,7 @@
 	// WriteHeader sends the header metedata for the given stream.
 	WriteHeader(s *Stream, md metadata.MD) error
 	// HandleStreams receives incoming streams using the given handler.
-	HandleStreams(func(*Stream, *sync.WaitGroup))
+	HandleStreams(func(*Stream))
 	// Close tears down the transport. Once it is called, the transport
 	// should not be accessed any more. All the pending streams and their
 	// handlers will be terminated asynchronously.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index ba1d66a..9bf3ed3 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -77,8 +77,7 @@
 	misbehaved
 )
 
-func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) {
-	defer wg.Done()
+func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
 	req := expectedRequest
 	resp := expectedResponse
 	if s.Method() == "foo.Large" {
@@ -100,16 +99,13 @@
 }
 
 // handleStreamSuspension blocks until s.ctx is canceled.
-func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) {
-	wg.Add(1)
+func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
 	go func() {
 		<-s.ctx.Done()
-		wg.Done()
 	}()
 }
 
-func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) {
-	defer wg.Done()
+func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
 	conn, ok := s.ServerTransport().(*http2Server)
 	if !ok {
 		t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
@@ -173,14 +169,12 @@
 		case suspended:
 			go transport.HandleStreams(h.handleStreamSuspension)
 		case misbehaved:
-			go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
-				wg.Add(1)
-				go h.handleStreamMisbehave(t, s, wg)
+			go transport.HandleStreams(func(s *Stream) {
+				go h.handleStreamMisbehave(t, s)
 			})
 		default:
-			go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
-				wg.Add(1)
-				go h.handleStream(t, s, wg)
+			go transport.HandleStreams(func(s *Stream) {
+				go h.handleStream(t, s)
 			})
 		}
 	}