internal/lsp: rewrite the stats using the newer telemetry features

This allows us to reduce the handler interface and delete the telemetry handler.
It is also safer and faster, and can be easily disabled along with the rest of
the telemetry system.

Change-Id: Ia4961d7f2e374f7dc22360d6a4020a065bfeae6f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/225957
Run-TryBot: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go
index 819e652..8351579 100644
--- a/internal/jsonrpc2/handler.go
+++ b/internal/jsonrpc2/handler.go
@@ -30,23 +30,6 @@
 
 	// Request is called near the start of processing any request.
 	Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context
-	// Response is called near the start of processing any response.
-	Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context
-	// Done is called when any request is fully processed.
-	// For calls, this means the response has also been processed, for notifies
-	// this is as soon as the message has been written to the stream.
-	// If err is set, it implies the request failed.
-	Done(ctx context.Context, err error)
-	// Read is called with a count each time some data is read from the stream.
-	// The read calls are delayed until after the data has been interpreted so
-	// that it can be attributed to a request/response.
-	Read(ctx context.Context, bytes int64) context.Context
-	// Wrote is called each time some data is written to the stream.
-	Wrote(ctx context.Context, bytes int64) context.Context
-	// Error is called with errors that cannot be delivered through the normal
-	// mechanisms, for instance a failure to process a notify cannot be delivered
-	// back to the other party.
-	Error(ctx context.Context, err error)
 }
 
 // Direction is used to indicate to a logger whether the logged message was being
@@ -89,19 +72,6 @@
 	return ctx
 }
 
-func (EmptyHandler) Done(ctx context.Context, err error) {
-}
-
-func (EmptyHandler) Read(ctx context.Context, bytes int64) context.Context {
-	return ctx
-}
-
-func (EmptyHandler) Wrote(ctx context.Context, bytes int64) context.Context {
-	return ctx
-}
-
-func (EmptyHandler) Error(ctx context.Context, err error) {}
-
 type defaultHandler struct{ EmptyHandler }
 
 func (defaultHandler) Deliver(ctx context.Context, r *Request, delivered bool) bool {
diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go
index 876e4ed..c7724ad 100644
--- a/internal/jsonrpc2/jsonrpc2.go
+++ b/internal/jsonrpc2/jsonrpc2.go
@@ -13,6 +13,9 @@
 	"fmt"
 	"sync"
 	"sync/atomic"
+
+	"golang.org/x/tools/internal/lsp/debug/tag"
+	"golang.org/x/tools/internal/telemetry/event"
 )
 
 // Conn is a JSON RPC 2 client server connection.
@@ -111,15 +114,19 @@
 	for _, h := range c.handlers {
 		ctx = h.Request(ctx, c, Send, request)
 	}
+	ctx, done := event.StartSpan(ctx, request.Method,
+		tag.Method.Of(request.Method),
+		tag.RPCDirection.Of(tag.Outbound),
+		tag.RPCID.Of(request.ID.String()),
+	)
 	defer func() {
-		for _, h := range c.handlers {
-			h.Done(ctx, err)
-		}
+		recordStatus(ctx, err)
+		done()
 	}()
+
+	event.Record(ctx, tag.Started.Of(1))
 	n, err := c.stream.Write(ctx, data)
-	for _, h := range c.handlers {
-		ctx = h.Wrote(ctx, n)
-	}
+	event.Record(ctx, tag.ReceivedBytes.Of(n))
 	return err
 }
 
@@ -146,6 +153,16 @@
 	for _, h := range c.handlers {
 		ctx = h.Request(ctx, c, Send, request)
 	}
+	ctx, done := event.StartSpan(ctx, request.Method,
+		tag.Method.Of(request.Method),
+		tag.RPCDirection.Of(tag.Outbound),
+		tag.RPCID.Of(request.ID.String()),
+	)
+	defer func() {
+		recordStatus(ctx, err)
+		done()
+	}()
+	event.Record(ctx, tag.Started.Of(1))
 	// We have to add ourselves to the pending map before we send, otherwise we
 	// are racing the response. Also add a buffer to rchan, so that if we get a
 	// wire response between the time this call is cancelled and id is deleted
@@ -158,15 +175,10 @@
 		c.pendingMu.Lock()
 		delete(c.pending, id)
 		c.pendingMu.Unlock()
-		for _, h := range c.handlers {
-			h.Done(ctx, err)
-		}
 	}()
 	// now we are ready to send
 	n, err := c.stream.Write(ctx, data)
-	for _, h := range c.handlers {
-		ctx = h.Wrote(ctx, n)
-	}
+	event.Record(ctx, tag.ReceivedBytes.Of(n))
 	if err != nil {
 		// sending failed, we will never get a response, so don't leave it pending
 		return err
@@ -174,9 +186,6 @@
 	// now wait for the response
 	select {
 	case response := <-rchan:
-		for _, h := range c.handlers {
-			ctx = h.Response(ctx, c, Receive, response)
-		}
 		// is it an error response?
 		if response.Error != nil {
 			return response.Error
@@ -261,13 +270,8 @@
 	if err != nil {
 		return err
 	}
-	for _, h := range r.conn.handlers {
-		ctx = h.Response(ctx, r.conn, Send, response)
-	}
 	n, err := r.conn.stream.Write(ctx, data)
-	for _, h := range r.conn.handlers {
-		ctx = h.Wrote(ctx, n)
-	}
+	event.Record(ctx, tag.ReceivedBytes.Of(n))
 
 	if err != nil {
 		// TODO(iancottrell): if a stream write fails, we really need to shut down
@@ -324,9 +328,6 @@
 		if err := json.Unmarshal(data, msg); err != nil {
-			// a badly formed message arrived, log it and continue
+			// a badly formed message arrived, drop it and continue
 			// we trust the stream to have isolated the error to just this message
-			for _, h := range c.handlers {
-				h.Error(runCtx, fmt.Errorf("unmarshal failed: %v", err))
-			}
 			continue
 		}
 		// Work out whether this is a request or response.
@@ -349,11 +350,20 @@
 			}
 			for _, h := range c.handlers {
 				reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
-				reqCtx = h.Read(reqCtx, n)
 			}
+			reqCtx, done := event.StartSpan(reqCtx, req.WireRequest.Method,
+				tag.Method.Of(req.WireRequest.Method),
+				tag.RPCDirection.Of(tag.Inbound),
+				tag.RPCID.Of(req.WireRequest.ID.String()),
+			)
+			event.Record(reqCtx,
+				tag.Started.Of(1),
+				tag.SentBytes.Of(n))
 			c.setHandling(req, true)
+			_, queueDone := event.StartSpan(reqCtx, "queued")
 			go func() {
 				<-thisRequest
+				queueDone()
 				req.state = requestSerial
 				defer func() {
 					c.setHandling(req, false)
@@ -361,9 +371,8 @@
 						req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
 					}
 					req.Parallel()
-					for _, h := range c.handlers {
-						h.Done(reqCtx, err)
-					}
+					recordStatus(reqCtx, nil)
+					done()
 					cancelReq()
 				}()
 				delivered := false
@@ -388,9 +397,6 @@
 				rchan <- response
 			}
 		default:
-			for _, h := range c.handlers {
-				h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring"))
-			}
 		}
 	}
 }
@@ -403,3 +409,11 @@
 	raw := json.RawMessage(data)
 	return &raw, nil
 }
+
+func recordStatus(ctx context.Context, err error) {
+	status := tag.StatusCode.Of("OK") // default label; swapped for ERROR below when err is set
+	if err != nil {
+		status = tag.StatusCode.Of("ERROR")
+	}
+	event.Label(ctx, status)
+}
diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go
index 861de10..05fdbb6 100644
--- a/internal/lsp/cmd/serve.go
+++ b/internal/lsp/cmd/serve.go
@@ -84,13 +84,12 @@
 	if s.app.Remote != "" {
 		network, addr := parseAddr(s.app.Remote)
 		ss = lsprpc.NewForwarder(network, addr,
-			lsprpc.WithTelemetry(true),
 			lsprpc.RemoteDebugAddress(s.RemoteDebug),
 			lsprpc.RemoteListenTimeout(s.RemoteListenTimeout),
 			lsprpc.RemoteLogfile(s.RemoteLogfile),
 		)
 	} else {
-		ss = lsprpc.NewStreamServer(cache.New(ctx, s.app.options), lsprpc.WithTelemetry(true))
+		ss = lsprpc.NewStreamServer(cache.New(ctx, s.app.options))
 	}
 
 	if s.Address != "" {
diff --git a/internal/lsp/debug/rpc.go b/internal/lsp/debug/rpc.go
index 0f39c7c..823ee9d 100644
--- a/internal/lsp/debug/rpc.go
+++ b/internal/lsp/debug/rpc.go
@@ -11,9 +11,11 @@
 	"net/http"
 	"sort"
 	"sync"
+	"time"
 
 	"golang.org/x/tools/internal/lsp/debug/tag"
 	"golang.org/x/tools/internal/telemetry/event"
+	"golang.org/x/tools/internal/telemetry/export"
 	"golang.org/x/tools/internal/telemetry/export/metric"
 )
 
@@ -92,11 +94,33 @@
 }
 
 func (r *rpcs) ProcessEvent(ctx context.Context, ev event.Event, tagMap event.TagMap) context.Context {
-	if !ev.IsRecord() {
+	switch {
+	case ev.IsEndSpan():
+		// calculate latency if this was an rpc span
+		span := export.GetSpan(ctx)
+		if span == nil {
+			return ctx
+		}
+		// is this a finished rpc span, if so it will have a status code record
+		for _, ev := range span.Events() {
+			code := tag.StatusCode.Get(ev.Map())
+			if code != "" {
+				elapsedTime := span.Finish().At.Sub(span.Start().At)
+				latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
+				statsCtx := event.Label1(ctx, tag.StatusCode.Of(code))
+				event.Record1(statsCtx, tag.Latency.Of(latencyMillis))
+			}
+		}
+		return ctx
+	case ev.IsRecord():
+		// fall through to the metrics handling logic
+	default:
+		// ignore all other event types
 		return ctx
 	}
 	r.mu.Lock()
 	defer r.mu.Unlock()
+	//TODO(38168): we should just deal with the events here and not use metrics
 	metrics := metric.Entries.Get(tagMap).([]metric.Data)
 	for _, data := range metrics {
 		for i, group := range data.Groups() {
diff --git a/internal/lsp/lsprpc/lsprpc.go b/internal/lsp/lsprpc/lsprpc.go
index 5502126..37003f1 100644
--- a/internal/lsp/lsprpc/lsprpc.go
+++ b/internal/lsp/lsprpc/lsprpc.go
@@ -36,37 +36,17 @@
 // The StreamServer type is a jsonrpc2.StreamServer that handles incoming
 // streams as a new LSP session, using a shared cache.
 type StreamServer struct {
-	withTelemetry bool
-	cache         *cache.Cache
+	cache *cache.Cache
 
 	// serverForTest may be set to a test fake for testing.
 	serverForTest protocol.Server
 }
 
-// A ServerOption configures the behavior of the LSP server.
-type ServerOption interface {
-	setServer(*StreamServer)
-}
-
-// WithTelemetry configures either a Server or Forwarder to instrument RPCs
-// with additional telemetry.
-type WithTelemetry bool
-
-func (t WithTelemetry) setServer(s *StreamServer) {
-	s.withTelemetry = bool(t)
-}
-
 // NewStreamServer creates a StreamServer using the shared cache. If
 // withTelemetry is true, each session is instrumented with telemetry that
 // records RPC statistics.
-func NewStreamServer(cache *cache.Cache, opts ...ServerOption) *StreamServer {
-	s := &StreamServer{
-		cache: cache,
-	}
-	for _, opt := range opts {
-		opt.setServer(s)
-	}
-	return s
+func NewStreamServer(cache *cache.Cache) *StreamServer {
+	return &StreamServer{cache: cache}
 }
 
 // debugInstance is the common functionality shared between client and server
@@ -156,9 +136,6 @@
 	}()
 	conn.AddHandler(protocol.ServerHandler(server))
 	conn.AddHandler(protocol.Canceller{})
-	if s.withTelemetry {
-		conn.AddHandler(telemetryHandler{})
-	}
 	executable, err := os.Executable()
 	if err != nil {
 		log.Printf("error getting gopls path: %v", err)
@@ -184,7 +161,6 @@
 	goplsPath string
 
 	// configuration
-	withTelemetry       bool
 	dialTimeout         time.Duration
 	retries             int
 	remoteDebug         string
@@ -197,10 +173,6 @@
 	setForwarder(*Forwarder)
 }
 
-func (t WithTelemetry) setForwarder(fwd *Forwarder) {
-	fwd.withTelemetry = bool(t)
-}
-
 // RemoteDebugAddress configures the address used by the auto-started Gopls daemon
 // for serving debug information.
 type RemoteDebugAddress string
@@ -290,9 +262,6 @@
 	clientConn.AddHandler(protocol.ServerHandler(server))
 	clientConn.AddHandler(protocol.Canceller{})
 	clientConn.AddHandler(forwarderHandler{})
-	if f.withTelemetry {
-		clientConn.AddHandler(telemetryHandler{})
-	}
 	g, ctx := errgroup.WithContext(ctx)
 	g.Go(func() error {
 		return serverConn.Run(ctx)
diff --git a/internal/lsp/lsprpc/telemetry.go b/internal/lsp/lsprpc/telemetry.go
deleted file mode 100644
index 86ca43d..0000000
--- a/internal/lsp/lsprpc/telemetry.go
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package lsprpc
-
-import (
-	"context"
-	"encoding/json"
-	"time"
-
-	"golang.org/x/tools/internal/jsonrpc2"
-	"golang.org/x/tools/internal/lsp/debug/tag"
-	"golang.org/x/tools/internal/telemetry/event"
-)
-
-type telemetryHandler struct{}
-
-func (h telemetryHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
-	stats := h.getStats(ctx)
-	if stats != nil {
-		stats.delivering()
-	}
-	return false
-}
-
-func (h telemetryHandler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool {
-	return false
-}
-
-func (h telemetryHandler) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
-	if r.Method == "" {
-		panic("no method in rpc stats")
-	}
-	stats := &rpcStats{
-		method:    r.Method,
-		id:        r.ID,
-		start:     time.Now(),
-		direction: direction,
-		payload:   r.Params,
-	}
-	ctx = context.WithValue(ctx, statsKey, stats)
-	mode := tag.Outbound
-	if direction == jsonrpc2.Receive {
-		mode = tag.Inbound
-	}
-	ctx, stats.close = event.StartSpan(ctx, r.Method,
-		tag.Method.Of(r.Method),
-		tag.RPCDirection.Of(mode),
-		tag.RPCID.Of(r.ID.String()),
-	)
-	event.Record(ctx, tag.Started.Of(1))
-	_, stats.delivering = event.StartSpan(ctx, "queued")
-	return ctx
-}
-
-func (h telemetryHandler) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
-	return ctx
-}
-
-func (h telemetryHandler) Done(ctx context.Context, err error) {
-	stats := h.getStats(ctx)
-	if err != nil {
-		ctx = event.Label(ctx, tag.StatusCode.Of("ERROR"))
-	} else {
-		ctx = event.Label(ctx, tag.StatusCode.Of("OK"))
-	}
-	elapsedTime := time.Since(stats.start)
-	latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
-	event.Record(ctx, tag.Latency.Of(latencyMillis))
-	stats.close()
-}
-
-func (h telemetryHandler) Read(ctx context.Context, bytes int64) context.Context {
-	event.Record(ctx, tag.SentBytes.Of(bytes))
-	return ctx
-}
-
-func (h telemetryHandler) Wrote(ctx context.Context, bytes int64) context.Context {
-	event.Record(ctx, tag.ReceivedBytes.Of(bytes))
-	return ctx
-}
-
-func (h telemetryHandler) Error(ctx context.Context, err error) {
-}
-
-func (h telemetryHandler) getStats(ctx context.Context) *rpcStats {
-	stats, ok := ctx.Value(statsKey).(*rpcStats)
-	if !ok || stats == nil {
-		method, ok := ctx.Value(tag.Method).(string)
-		if !ok {
-			method = "???"
-		}
-		stats = &rpcStats{
-			method: method,
-			close:  func() {},
-		}
-	}
-	return stats
-}
-
-type rpcStats struct {
-	method     string
-	direction  jsonrpc2.Direction
-	id         *jsonrpc2.ID
-	payload    *json.RawMessage
-	start      time.Time
-	delivering func()
-	close      func()
-}
-
-type statsKeyType int
-
-const statsKey = statsKeyType(0)