internal/lsp: fix the incomplete and broken commit

https://go-review.googlesource.com/c/tools/+/186297 was the wrong commit, this
adds the changes that were supposed to be part of it.

Change-Id: I0c4783195c2670f89c3213dce2511d98f21f1cf4
Reviewed-on: https://go-review.googlesource.com/c/tools/+/186379
Run-TryBot: Ian Cottrell <iancottrell@google.com>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go
index 76919a8..6eb6e9e 100644
--- a/internal/jsonrpc2/jsonrpc2.go
+++ b/internal/jsonrpc2/jsonrpc2.go
@@ -62,7 +62,7 @@
 // You must call Run for the connection to be active.
 func NewConn(s Stream) *Conn {
 	conn := &Conn{
-		handlers: []Handler{defaultHandler{}, &tracer{}},
+		handlers: []Handler{defaultHandler{}},
 		stream:   s,
 		pending:  make(map[ID]chan *WireResponse),
 		handling: make(map[ID]*Request),
diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go
index 275b002..5e7d46d 100644
--- a/internal/lsp/cmd/serve.go
+++ b/internal/lsp/cmd/serve.go
@@ -83,7 +83,7 @@
 
 	// For debugging purposes only.
 	run := func(ctx context.Context, srv *lsp.Server) {
-		srv.Conn.AddHandler(&handler{trace: s.Trace, out: out})
+		srv.Conn.AddHandler(&handler{loggingRPCs: s.Trace, out: out})
 		go srv.Run(ctx)
 	}
 	if s.Address != "" {
@@ -94,7 +94,7 @@
 	}
 	stream := jsonrpc2.NewHeaderStream(os.Stdin, os.Stdout)
 	ctx, srv := lsp.NewServer(ctx, s.app.cache, stream)
-	srv.Conn.AddHandler(&handler{trace: s.Trace, out: out})
+	srv.Conn.AddHandler(&handler{loggingRPCs: s.Trace, out: out})
 	return srv.Run(ctx)
 }
 
@@ -119,8 +119,8 @@
 }
 
 type handler struct {
-	trace bool
-	out   io.Writer
+	loggingRPCs bool
+	out         io.Writer
 }
 
 type rpcStats struct {
@@ -129,6 +129,7 @@
 	id        *jsonrpc2.ID
 	payload   *json.RawMessage
 	start     time.Time
+	close     func()
 }
 
 type statsKeyType int
@@ -144,49 +145,63 @@
 }
 
 func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
-	if !h.trace {
-		return ctx
+	if r.Method == "" {
+		panic("no method in rpc stats")
 	}
-	stats := &rpcStats{
+	s := &rpcStats{
 		method:    r.Method,
-		direction: direction,
 		start:     time.Now(),
+		direction: direction,
 		payload:   r.Params,
 	}
-	ctx = context.WithValue(ctx, statsKey, stats)
+	mode := telemetry.Outbound
+	if direction == jsonrpc2.Receive {
+		mode = telemetry.Inbound
+	}
+	ctx, s.close = trace.StartSpan(ctx, r.Method,
+		tag.Tag{Key: telemetry.Method, Value: r.Method},
+		tag.Tag{Key: telemetry.RPCDirection, Value: mode},
+		tag.Tag{Key: telemetry.RPCID, Value: r.ID},
+	)
+	telemetry.Started.Record(ctx, 1)
 	return ctx
 }
 
 func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
 	stats := h.getStats(ctx)
-	h.log(direction, r.ID, 0, stats.method, r.Result, nil)
+	h.logRPC(direction, r.ID, 0, stats.method, r.Result, nil)
 	return ctx
 }
 
 func (h *handler) Done(ctx context.Context, err error) {
-	if !h.trace {
-		return
-	}
 	stats := h.getStats(ctx)
-	h.log(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err)
+	h.logRPC(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err)
+	if err != nil {
+		ctx = telemetry.StatusCode.With(ctx, "ERROR")
+	} else {
+		ctx = telemetry.StatusCode.With(ctx, "OK")
+	}
+	elapsedTime := time.Since(stats.start)
+	latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
+	telemetry.Latency.Record(ctx, latencyMillis)
+	stats.close()
 }
 
 func (h *handler) Read(ctx context.Context, bytes int64) context.Context {
+	telemetry.SentBytes.Record(ctx, bytes)
 	return ctx
 }
 
 func (h *handler) Wrote(ctx context.Context, bytes int64) context.Context {
+	telemetry.ReceivedBytes.Record(ctx, bytes)
 	return ctx
 }
 
 const eol = "\r\n\r\n\r\n"
 
 func (h *handler) Error(ctx context.Context, err error) {
-	if !h.trace {
-		return
-	}
 	stats := h.getStats(ctx)
-	h.log(stats.direction, stats.id, 0, stats.method, nil, err)
+	h.logRPC(stats.direction, stats.id, 0, stats.method, nil, err)
 }
 
 func (h *handler) getStats(ctx context.Context) *rpcStats {
@@ -194,13 +209,14 @@
 	if !ok || stats == nil {
 		stats = &rpcStats{
 			method: "???",
+			close:  func() {},
 		}
 	}
 	return stats
 }
 
-func (h *handler) log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) {
-	if !h.trace {
+func (h *handler) logRPC(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) {
+	if !h.loggingRPCs {
 		return
 	}
 	const eol = "\r\n\r\n\r\n"
@@ -249,90 +265,3 @@
 	fmt.Fprintf(outx, ".\r\nParams: %s%s", params, eol)
 	fmt.Fprintf(h.out, "%s", outx.String())
 }
-
-type rpcStats struct {
-	server bool
-	method string
-	close  func()
-	start  time.Time
-}
-
-func start(ctx context.Context, server bool, method string, id *ID) (context.Context, *rpcStats) {
-	if method == "" {
-		panic("no method in rpc stats")
-	}
-	s := &rpcStats{
-		server: server,
-		method: method,
-		start:  time.Now(),
-	}
-	mode := telemetry.Outbound
-	if server {
-		mode = telemetry.Inbound
-	}
-	ctx, s.close = trace.StartSpan(ctx, method,
-		tag.Tag{Key: telemetry.Method, Value: method},
-		tag.Tag{Key: telemetry.RPCDirection, Value: mode},
-		tag.Tag{Key: telemetry.RPCID, Value: id},
-	)
-	telemetry.Started.Record(ctx, 1)
-	return ctx, s
-}
-
-func (s *rpcStats) end(ctx context.Context, err *error) {
-	if err != nil && *err != nil {
-		ctx = telemetry.StatusCode.With(ctx, "ERROR")
-	} else {
-		ctx = telemetry.StatusCode.With(ctx, "OK")
-	}
-	elapsedTime := time.Since(s.start)
-	latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
-	telemetry.Latency.Record(ctx, latencyMillis)
-	s.close()
-}
-
-type statsKeyType int
-
-const statsKey = statsKeyType(0)
-
-type tracer struct {
-}
-
-func (h *tracer) Deliver(ctx context.Context, r *Request, delivered bool) bool {
-	return false
-}
-
-func (h *tracer) Cancel(ctx context.Context, conn *Conn, id ID, cancelled bool) bool {
-	return false
-}
-
-func (h *tracer) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context {
-	ctx, stats := start(ctx, direction == Receive, r.Method, r.ID)
-	ctx = context.WithValue(ctx, statsKey, stats)
-	return ctx
-}
-
-func (h *tracer) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context {
-	return ctx
-}
-
-func (h *tracer) Done(ctx context.Context, err error) {
-	stats, ok := ctx.Value(statsKey).(*rpcStats)
-	if ok && stats != nil {
-		stats.end(ctx, &err)
-	}
-}
-
-func (h *tracer) Read(ctx context.Context, bytes int64) context.Context {
-	telemetry.SentBytes.Record(ctx, bytes)
-	return ctx
-}
-
-func (h *tracer) Wrote(ctx context.Context, bytes int64) context.Context {
-	telemetry.ReceivedBytes.Record(ctx, bytes)
-	return ctx
-}
-
-func (h *tracer) Error(ctx context.Context, err error) {
-	log.Printf("%v", err)
-}