server: break up the Server.Serve method into some reusable parts
Updates grpc/grpc-go#75
diff --git a/server.go b/server.go
index dd86427..f6ee266 100644
--- a/server.go
+++ b/server.go
@@ -264,49 +264,74 @@
}
s.mu.Unlock()
+ go s.serveNewHTTP2Transport(c, authInfo)
+ }
+}
+
+func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
+ st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
+ if err != nil {
+ s.mu.Lock()
+ s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
+ s.mu.Unlock()
+ c.Close()
+ grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err)
+ return
+ }
+ if !s.addConn(st) {
+ c.Close()
+ return
+ }
+ s.serveStreams(st)
+}
+
+func (s *Server) serveStreams(st transport.ServerTransport) {
+ defer s.removeConn(st)
+ defer st.Close()
+ var wg sync.WaitGroup
+ st.HandleStreams(func(stream *transport.Stream) {
+ wg.Add(1)
go func() {
- st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
- if err != nil {
- s.mu.Lock()
- s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
- s.mu.Unlock()
- c.Close()
- grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err)
- return
- }
- defer st.Close()
- s.mu.Lock()
- if s.conns == nil {
- s.mu.Unlock()
- return
- }
- s.conns[st] = true
- s.mu.Unlock()
- var wg sync.WaitGroup
- st.HandleStreams(func(stream *transport.Stream) {
- var trInfo *traceInfo
- if EnableTracing {
- trInfo = &traceInfo{
- tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
- }
- trInfo.firstLine.client = false
- trInfo.firstLine.remoteAddr = st.RemoteAddr()
- stream.TraceContext(trInfo.tr)
- if dl, ok := stream.Context().Deadline(); ok {
- trInfo.firstLine.deadline = dl.Sub(time.Now())
- }
- }
- wg.Add(1)
- go func() {
- s.handleStream(st, stream, trInfo)
- wg.Done()
- }()
- })
- wg.Wait()
- s.mu.Lock()
- delete(s.conns, st)
- s.mu.Unlock()
+ defer wg.Done()
+ s.handleStream(st, stream, s.traceInfo(st, stream))
}()
+ })
+ wg.Wait()
+}
+
+// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
+// If tracing is not enabled, it returns nil.
+func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
+ if !EnableTracing {
+ return nil
+ }
+ trInfo = &traceInfo{
+ tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
+ }
+ trInfo.firstLine.client = false
+ trInfo.firstLine.remoteAddr = st.RemoteAddr()
+ stream.TraceContext(trInfo.tr)
+ if dl, ok := stream.Context().Deadline(); ok {
+ trInfo.firstLine.deadline = dl.Sub(time.Now())
+ }
+ return trInfo
+}
+
+func (s *Server) addConn(st transport.ServerTransport) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.conns == nil {
+ return false
+ }
+ s.conns[st] = true
+ return true
+}
+
+func (s *Server) removeConn(st transport.ServerTransport) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.conns != nil {
+ delete(s.conns, st)
}
}