Merge pull request #387 from iamqizhao/master
Refactor server side tracing
diff --git a/call.go b/call.go
index b5f9292..8b68809 100644
--- a/call.go
+++ b/call.go
@@ -117,7 +117,7 @@
}
}()
if EnableTracing {
- c.traceInfo.tr = trace.New("grpc.Sent."+transport.MethodFamily(method), method)
+ c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
defer c.traceInfo.tr.Finish()
c.traceInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
diff --git a/server.go b/server.go
index a5c8dc3..5c20dd4 100644
--- a/server.go
+++ b/server.go
@@ -247,7 +247,7 @@
c.Close()
return nil
}
- st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo, EnableTracing)
+ st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
if err != nil {
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
s.mu.Unlock()
@@ -259,8 +259,24 @@
s.mu.Unlock()
go func() {
- st.HandleStreams(func(stream *transport.Stream) {
- s.handleStream(st, stream)
+ st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) {
+ var trInfo *traceInfo
+ if EnableTracing {
+ trInfo = &traceInfo{
+ tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
+ }
+ trInfo.firstLine.client = false
+ trInfo.firstLine.remoteAddr = st.RemoteAddr()
+ stream.TraceContext(trInfo.tr)
+ if dl, ok := stream.Context().Deadline(); ok {
+ trInfo.firstLine.deadline = dl.Sub(time.Now())
+ }
+ }
+ wg.Add(1)
+ go func() {
+ s.handleStream(st, stream, trInfo)
+ wg.Done()
+ }()
})
s.mu.Lock()
delete(s.conns, st)
@@ -284,21 +300,15 @@
return t.Write(stream, p, opts)
}
-func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) (err error) {
- var traceInfo traceInfo
- if EnableTracing {
- traceInfo.tr = stream.Trace()
- defer traceInfo.tr.Finish()
- traceInfo.firstLine.client = false
- traceInfo.firstLine.remoteAddr = t.RemoteAddr()
- if dl, ok := stream.Context().Deadline(); ok {
- traceInfo.firstLine.deadline = dl.Sub(time.Now())
- }
- traceInfo.tr.LazyLog(&traceInfo.firstLine, false)
+func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
+ if trInfo != nil {
+ defer trInfo.tr.Finish()
+ trInfo.firstLine.client = false
+ trInfo.tr.LazyLog(&trInfo.firstLine, false)
defer func() {
if err != nil && err != io.EOF {
- traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
- traceInfo.tr.SetError()
+ trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
+ trInfo.tr.SetError()
}
}()
}
@@ -330,8 +340,8 @@
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
- if traceInfo.tr != nil {
- traceInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
+ if trInfo != nil {
+ trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
return nil
}
@@ -344,9 +354,9 @@
statusCode = convertCode(appErr)
statusDesc = appErr.Error()
}
- if traceInfo.tr != nil && statusCode != codes.OK {
- traceInfo.tr.LazyLog(stringer(statusDesc), true)
- traceInfo.tr.SetError()
+ if trInfo != nil && statusCode != codes.OK {
+ trInfo.tr.LazyLog(stringer(statusDesc), true)
+ trInfo.tr.SetError()
}
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
@@ -355,8 +365,8 @@
}
return nil
}
- if traceInfo.tr != nil {
- traceInfo.tr.LazyLog(stringer("OK"), false)
+ if trInfo != nil {
+ trInfo.tr.LazyLog(stringer("OK"), false)
}
opts := &transport.Options{
Last: true,
@@ -375,8 +385,8 @@
}
return err
}
- if traceInfo.tr != nil {
- traceInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
+ if trInfo != nil {
+ trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
return t.WriteStatus(stream, statusCode, statusDesc)
default:
@@ -385,30 +395,24 @@
}
}
-func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc) (err error) {
+func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
ss := &serverStream{
- t: t,
- s: stream,
- p: &parser{s: stream},
- codec: s.opts.codec,
- tracing: EnableTracing,
+ t: t,
+ s: stream,
+ p: &parser{s: stream},
+ codec: s.opts.codec,
+ trInfo: trInfo,
}
- if ss.tracing {
- ss.traceInfo.tr = stream.Trace()
- ss.traceInfo.firstLine.client = false
- ss.traceInfo.firstLine.remoteAddr = t.RemoteAddr()
- if dl, ok := stream.Context().Deadline(); ok {
- ss.traceInfo.firstLine.deadline = dl.Sub(time.Now())
- }
- ss.traceInfo.tr.LazyLog(&ss.traceInfo.firstLine, false)
+ if trInfo != nil {
+ trInfo.tr.LazyLog(&trInfo.firstLine, false)
defer func() {
ss.mu.Lock()
if err != nil && err != io.EOF {
- ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
- ss.traceInfo.tr.SetError()
+ trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
+ trInfo.tr.SetError()
}
- ss.traceInfo.tr.Finish()
- ss.traceInfo.tr = nil
+ trInfo.tr.Finish()
+ trInfo.tr = nil
ss.mu.Unlock()
}()
}
@@ -421,13 +425,13 @@
ss.statusDesc = appErr.Error()
}
}
- if ss.tracing {
+ if trInfo != nil {
ss.mu.Lock()
if ss.statusCode != codes.OK {
- ss.traceInfo.tr.LazyLog(stringer(ss.statusDesc), true)
- ss.traceInfo.tr.SetError()
+ trInfo.tr.LazyLog(stringer(ss.statusDesc), true)
+ trInfo.tr.SetError()
} else {
- ss.traceInfo.tr.LazyLog(stringer("OK"), false)
+ trInfo.tr.LazyLog(stringer("OK"), false)
}
ss.mu.Unlock()
}
@@ -435,7 +439,7 @@
}
-func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
+func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
sm := stream.Method()
if sm != "" && sm[0] == '/' {
sm = sm[1:]
@@ -458,11 +462,11 @@
}
// Unary RPC or Streaming RPC?
if md, ok := srv.md[method]; ok {
- s.processUnaryRPC(t, stream, srv, md)
+ s.processUnaryRPC(t, stream, srv, md, trInfo)
return
}
if sd, ok := srv.sd[method]; ok {
- s.processStreamingRPC(t, stream, srv, sd)
+ s.processStreamingRPC(t, stream, srv, sd, trInfo)
return
}
if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil {
diff --git a/stream.go b/stream.go
index 28938a3..21c51e8 100644
--- a/stream.go
+++ b/stream.go
@@ -126,13 +126,13 @@
tracing: EnableTracing,
}
if cs.tracing {
- cs.traceInfo.tr = trace.New("grpc.Sent."+transport.MethodFamily(method), method)
- cs.traceInfo.firstLine.client = true
+ cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
+ cs.trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
- cs.traceInfo.firstLine.deadline = deadline.Sub(time.Now())
+ cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
}
- cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false)
- ctx = trace.NewContext(ctx, cs.traceInfo.tr)
+ cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
+ ctx = trace.NewContext(ctx, cs.trInfo.tr)
}
s, err := t.NewStream(ctx, callHdr)
if err != nil {
@@ -154,10 +154,10 @@
tracing bool // set to EnableTracing when the clientStream is created.
- mu sync.Mutex // protects traceInfo
- // traceInfo.tr is set when the clientStream is created (if EnableTracing is true),
+ mu sync.Mutex // protects trInfo.tr
+ // trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
- traceInfo traceInfo
+ trInfo traceInfo
}
func (cs *clientStream) Context() context.Context {
@@ -181,8 +181,8 @@
func (cs *clientStream) SendMsg(m interface{}) (err error) {
if cs.tracing {
cs.mu.Lock()
- if cs.traceInfo.tr != nil {
- cs.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
+ if cs.trInfo.tr != nil {
+ cs.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
}
cs.mu.Unlock()
}
@@ -213,8 +213,8 @@
if err == nil {
if cs.tracing {
cs.mu.Lock()
- if cs.traceInfo.tr != nil {
- cs.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
+ if cs.trInfo.tr != nil {
+ cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
}
cs.mu.Unlock()
}
@@ -266,15 +266,15 @@
}
cs.mu.Lock()
defer cs.mu.Unlock()
- if cs.traceInfo.tr != nil {
+ if cs.trInfo.tr != nil {
if err == nil || err == io.EOF {
- cs.traceInfo.tr.LazyPrintf("RPC: [OK]")
+ cs.trInfo.tr.LazyPrintf("RPC: [OK]")
} else {
- cs.traceInfo.tr.LazyPrintf("RPC: [%v]", err)
- cs.traceInfo.tr.SetError()
+ cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)
+ cs.trInfo.tr.SetError()
}
- cs.traceInfo.tr.Finish()
- cs.traceInfo.tr = nil
+ cs.trInfo.tr.Finish()
+ cs.trInfo.tr = nil
}
}
@@ -298,13 +298,9 @@
codec Codec
statusCode codes.Code
statusDesc string
+ trInfo *traceInfo
- tracing bool // set to EnableTracing when the serverStream is created.
-
- mu sync.Mutex // protects traceInfo
- // traceInfo.tr is set when the serverStream is created (if EnableTracing is true),
- // and is set to nil when the serverStream's finish method is called.
- traceInfo traceInfo
+ mu sync.Mutex // protects trInfo.tr after the service handler runs.
}
func (ss *serverStream) Context() context.Context {
@@ -325,13 +321,15 @@
func (ss *serverStream) SendMsg(m interface{}) (err error) {
defer func() {
- if ss.tracing {
+ if ss.trInfo != nil {
ss.mu.Lock()
- if err == nil {
- ss.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
- } else {
- ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
- ss.traceInfo.tr.SetError()
+ if ss.trInfo.tr != nil {
+ if err == nil {
+ ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
+ } else {
+ ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
+ ss.trInfo.tr.SetError()
+ }
}
ss.mu.Unlock()
}
@@ -346,13 +344,15 @@
func (ss *serverStream) RecvMsg(m interface{}) (err error) {
defer func() {
- if ss.tracing {
+ if ss.trInfo != nil {
ss.mu.Lock()
- if err == nil {
- ss.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
- } else if err != io.EOF {
- ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
- ss.traceInfo.tr.SetError()
+ if ss.trInfo.tr != nil {
+ if err == nil {
+ ss.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
+ } else if err != io.EOF {
+ ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
+ ss.trInfo.tr.SetError()
+ }
}
ss.mu.Unlock()
}
diff --git a/trace.go b/trace.go
index 9b88444..cde04fb 100644
--- a/trace.go
+++ b/trace.go
@@ -38,6 +38,7 @@
"fmt"
"io"
"net"
+ "strings"
"time"
"golang.org/x/net/trace"
@@ -47,6 +48,19 @@
// This should only be set before any RPCs are sent or received by this program.
var EnableTracing = true
+// methodFamily returns the trace family for the given method: the last dotted
+// component of the service name, e.g. "/pkg.Service/GetFoo" becomes "Service".
+func methodFamily(m string) string {
+ m = strings.TrimPrefix(m, "/") // remove leading slash
+ if i := strings.Index(m, "/"); i >= 0 {
+ m = m[:i] // drop the "/Method" suffix, leaving the service name
+ }
+ if i := strings.LastIndex(m, "."); i >= 0 {
+ m = m[i+1:] // cut down to last dotted component (drops the package)
+ }
+ return m
+}
+
// traceInfo contains tracing information for an RPC.
type traceInfo struct {
tr trace.Trace
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 89371c2..ed8fde0 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -45,7 +45,6 @@
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
- "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
@@ -81,8 +80,6 @@
fc *inFlow
// sendQuotaPool provides flow control to outbound message.
sendQuotaPool *quotaPool
- // tracing indicates whether tracing is on for this http2Server transport.
- tracing bool
mu sync.Mutex // guard the following
state transportState
@@ -93,7 +90,7 @@
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
// returned if something goes wrong.
-func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo, tracing bool) (_ ServerTransport, err error) {
+func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) {
framer := newFramer(conn)
// Send initial settings as connection preface to client.
var settings []http2.Setting
@@ -118,16 +115,15 @@
}
var buf bytes.Buffer
t := &http2Server{
- conn: conn,
- authInfo: authInfo,
- framer: framer,
- hBuf: &buf,
- hEnc: hpack.NewEncoder(&buf),
- maxStreams: maxStreams,
- controlBuf: newRecvBuffer(),
- fc: &inFlow{limit: initialConnWindowSize},
- sendQuotaPool: newQuotaPool(defaultWindowSize),
- tracing: tracing,
+ conn: conn,
+ authInfo: authInfo,
+ framer: framer,
+ hBuf: &buf,
+ hEnc: hpack.NewEncoder(&buf),
+ maxStreams: maxStreams,
+ controlBuf: newRecvBuffer(),
+ fc: &inFlow{limit: initialConnWindowSize},
+ sendQuotaPool: newQuotaPool(defaultWindowSize),
state: reachable,
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
@@ -142,7 +138,7 @@
// operateHeader takes action on the decoded headers. It returns the current
// stream if there are remaining headers on the wire (in the following
// Continuation frame).
-func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream), wg *sync.WaitGroup) (pendingStream *Stream) {
+func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) {
defer func() {
if pendingStream == nil {
hDec.state = decodeState{}
@@ -206,21 +202,13 @@
recv: s.buf,
}
s.method = hDec.state.method
- if t.tracing {
- s.tr = trace.New("grpc.Recv."+MethodFamily(s.method), s.method)
- s.ctx = trace.NewContext(s.ctx, s.tr)
- }
- wg.Add(1)
- go func() {
- handle(s)
- wg.Done()
- }()
+ handle(s, wg)
return nil
}
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
-func (t *http2Server) HandleStreams(handle func(*Stream)) {
+func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
// Check the validity of client preface.
preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil {
diff --git a/transport/transport.go b/transport/transport.go
index 93efeae..c319a5f 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -43,7 +43,6 @@
"fmt"
"io"
"net"
- "strings"
"sync"
"time"
@@ -54,19 +53,6 @@
"google.golang.org/grpc/metadata"
)
-// MethodFamily returns the trace family for the given method.
-// It turns "/pkg.Service/GetFoo" into "pkg.Service".
-func MethodFamily(m string) string {
- m = strings.TrimPrefix(m, "/") // remove leading slash
- if i := strings.Index(m, "/"); i >= 0 {
- m = m[:i] // remove everything from second slash
- }
- if i := strings.LastIndex(m, "."); i >= 0 {
- m = m[i+1:] // cut down to last dotted component
- }
- return m
-}
-
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
@@ -213,8 +199,6 @@
// the status received from the server.
statusCode codes.Code
statusDesc string
- // tracing information
- tr trace.Trace
}
// Header acquires the key-value pairs of header metadata once it
@@ -249,9 +233,9 @@
return s.ctx
}
-// Trace returns the trace.Trace of the stream.
-func (s *Stream) Trace() trace.Trace {
- return s.tr
+// TraceContext wraps the stream's context with trace.NewContext so that s.ctx carries tr.
+func (s *Stream) TraceContext(tr trace.Trace) {
+ s.ctx = trace.NewContext(s.ctx, tr) // NOTE(review): plain field write with no locking; presumably only safe before s is handed to handler goroutines — confirm
}
// Method returns the method for the stream.
@@ -330,8 +314,8 @@
// NewServerTransport creates a ServerTransport with conn or non-nil error
// if it fails.
-func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo, tracing bool) (ServerTransport, error) {
- return newHTTP2Server(conn, maxStreams, authInfo, tracing)
+func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (ServerTransport, error) {
+ return newHTTP2Server(conn, maxStreams, authInfo) // protocol is not consulted: an HTTP/2 server transport is always built
}
// ConnectOptions covers all relevant options for dialing a server.
@@ -407,7 +391,7 @@
// WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
- HandleStreams(func(*Stream))
+ HandleStreams(func(*Stream, *sync.WaitGroup))
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 70d345a..ba1d66a 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -77,7 +77,8 @@
misbehaved
)
-func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
+func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) {
+ defer wg.Done()
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
@@ -99,11 +100,16 @@
}
-// handleStreamSuspension blocks until s.ctx is canceled.
-func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
- <-s.ctx.Done()
+// handleStreamSuspension returns at once; a goroutine tracked by wg waits until s.ctx is canceled.
+func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-s.ctx.Done()
+ }()
}
-func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
+func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) {
+ defer wg.Done()
conn, ok := s.ServerTransport().(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
@@ -150,7 +156,7 @@
if err != nil {
return
}
- transport, err := NewServerTransport("http2", conn, maxStreams, nil, false)
+ transport, err := NewServerTransport("http2", conn, maxStreams, nil)
if err != nil {
return
}
@@ -167,12 +173,14 @@
case suspended:
go transport.HandleStreams(h.handleStreamSuspension)
case misbehaved:
- go transport.HandleStreams(func(s *Stream) {
- h.handleStreamMisbehave(t, s)
+ go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
+ wg.Add(1)
+ go h.handleStreamMisbehave(t, s, wg)
})
default:
- go transport.HandleStreams(func(s *Stream) {
- h.handleStream(t, s)
+ go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
+ wg.Add(1)
+ go h.handleStream(t, s, wg)
})
}
}