internal/lsp/lsprpc: update binder tests to handle forwarding

Update the new binder tests to run both with a standalone server, and
with a forwarding chain.

Make a few superficial improvements along the way as well.

Change-Id: Icd197698093a3f6149ab58171806b2388ed75b7f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/321134
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/jsonrpc2_v2/conn.go b/internal/jsonrpc2_v2/conn.go
index 7d99a02..606c3f9 100644
--- a/internal/jsonrpc2_v2/conn.go
+++ b/internal/jsonrpc2_v2/conn.go
@@ -112,6 +112,7 @@
 	go c.readIncoming(ctx, reader, readToQueue)
 	go c.manageQueue(ctx, options.Preempter, readToQueue, queueToDeliver)
 	go c.deliverMessages(ctx, options.Handler, queueToDeliver)
+
 	// releaseing the writer must be the last thing we do in case any requests
 	// are blocked waiting for the connection to be ready
 	c.writerBox <- options.Framer.Writer(rwc)
diff --git a/internal/jsonrpc2_v2/jsonrpc2.go b/internal/jsonrpc2_v2/jsonrpc2.go
index faaf205..271f42c 100644
--- a/internal/jsonrpc2_v2/jsonrpc2.go
+++ b/internal/jsonrpc2_v2/jsonrpc2.go
@@ -57,11 +57,11 @@
 	return f(ctx, req)
 }
 
-// async is a small helper for things with an asynchronous result that you can
-// wait for.
+// async is a small helper for operations with an asynchronous result that you
+// can wait for.
 type async struct {
-	ready  chan struct{}
-	errBox chan error
+	ready  chan struct{} // signals that the operation has completed
+	errBox chan error    // guards the operation result
 }
 
 func newAsync() *async {
diff --git a/internal/jsonrpc2_v2/jsonrpc2_test.go b/internal/jsonrpc2_v2/jsonrpc2_test.go
index 6d057b4..1157779 100644
--- a/internal/jsonrpc2_v2/jsonrpc2_test.go
+++ b/internal/jsonrpc2_v2/jsonrpc2_test.go
@@ -126,7 +126,7 @@
 func testConnection(t *testing.T, framer jsonrpc2.Framer) {
 	stacktest.NoLeak(t)
 	ctx := eventtest.NewContext(context.Background(), t)
-	listener, err := jsonrpc2.NetPipe(ctx)
+	listener, err := jsonrpc2.NetPipeListener(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/internal/jsonrpc2_v2/net.go b/internal/jsonrpc2_v2/net.go
index c8cfaab..0b413d8 100644
--- a/internal/jsonrpc2_v2/net.go
+++ b/internal/jsonrpc2_v2/net.go
@@ -80,11 +80,11 @@
 	return n.dialer.DialContext(ctx, n.network, n.address)
 }
 
-// NetPipe returns a new Listener that listens using net.Pipe.
+// NetPipeListener returns a new Listener that listens using net.Pipe.
 // It is only possibly to connect to it using the Dialier returned by the
 // Dialer method, each call to that method will generate a new pipe the other
 // side of which will be returnd from the Accept call.
-func NetPipe(ctx context.Context) (Listener, error) {
+func NetPipeListener(ctx context.Context) (Listener, error) {
 	return &netPiper{
 		done:   make(chan struct{}),
 		dialed: make(chan io.ReadWriteCloser),
diff --git a/internal/jsonrpc2_v2/serve_test.go b/internal/jsonrpc2_v2/serve_test.go
index 7f1dbc3..26cf6a5 100644
--- a/internal/jsonrpc2_v2/serve_test.go
+++ b/internal/jsonrpc2_v2/serve_test.go
@@ -89,7 +89,7 @@
 			return jsonrpc2.NetListener(ctx, "tcp", "localhost:0", jsonrpc2.NetListenOptions{})
 		}},
 		{"pipe", func(ctx context.Context) (jsonrpc2.Listener, error) {
-			return jsonrpc2.NetPipe(ctx)
+			return jsonrpc2.NetPipeListener(ctx)
 		}},
 	}
 
diff --git a/internal/lsp/lsprpc/binder.go b/internal/lsp/lsprpc/binder.go
index 3f5cb3b..61f82de 100644
--- a/internal/lsp/lsprpc/binder.go
+++ b/internal/lsp/lsprpc/binder.go
@@ -7,9 +7,12 @@
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 
+	"golang.org/x/tools/internal/event"
 	jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
 	"golang.org/x/tools/internal/lsp/protocol"
+	"golang.org/x/tools/internal/xcontext"
 	errors "golang.org/x/xerrors"
 )
 
@@ -87,8 +90,19 @@
 		return opts, err
 	}
 	server := protocol.ServerDispatcherV2(serverConn)
+	preempter := &canceler{
+		conn: conn,
+	}
+	detached := xcontext.Detach(ctx)
+	go func() {
+		conn.Wait()
+		if err := serverConn.Close(); err != nil {
+			event.Log(detached, fmt.Sprintf("closing remote connection: %v", err))
+		}
+	}()
 	return jsonrpc2_v2.ConnectionOptions{
-		Handler: protocol.ServerHandlerV2(server),
+		Handler:   protocol.ServerHandlerV2(server),
+		Preempter: preempter,
 	}, nil
 }
 
diff --git a/internal/lsp/lsprpc/binder_test.go b/internal/lsp/lsprpc/binder_test.go
index d29de0f..5cbdb20 100644
--- a/internal/lsp/lsprpc/binder_test.go
+++ b/internal/lsp/lsprpc/binder_test.go
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// TODO(rFindley): move this to lsprpc_test once it no longer shares with
-//                 lsprpc_test.go.
+// TODO(rFindley): move this to the lsprpc_test package once it no longer
+//                 shares with lsprpc_test.go.
 
 package lsprpc
 
@@ -19,25 +19,41 @@
 )
 
 type testEnv struct {
-	listener  jsonrpc2_v2.Listener
-	conn      *jsonrpc2_v2.Connection
-	rpcServer *jsonrpc2_v2.Server
+	listener jsonrpc2_v2.Listener
+	server   *jsonrpc2_v2.Server
+
+	// non-nil if constructed with forwarded=true
+	fwdListener jsonrpc2_v2.Listener
+	fwdServer   *jsonrpc2_v2.Server
+
+	// the ingoing connection, either to the forwarder or server
+	conn *jsonrpc2_v2.Connection
 }
 
 func (e testEnv) Shutdown(t *testing.T) {
 	if err := e.listener.Close(); err != nil {
 		t.Error(err)
 	}
+	if e.fwdListener != nil {
+		if err := e.fwdListener.Close(); err != nil {
+			t.Error(err)
+		}
+	}
 	if err := e.conn.Close(); err != nil {
 		t.Error(err)
 	}
-	if err := e.rpcServer.Wait(); err != nil {
+	if e.fwdServer != nil {
+		if err := e.fwdServer.Wait(); err != nil {
+			t.Error(err)
+		}
+	}
+	if err := e.server.Wait(); err != nil {
 		t.Error(err)
 	}
 }
 
-func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client) testEnv {
-	listener, err := jsonrpc2_v2.NetPipe(ctx)
+func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client, forwarded bool) testEnv {
+	listener, err := jsonrpc2_v2.NetPipeListener(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -49,69 +65,102 @@
 	if err != nil {
 		t.Fatal(err)
 	}
+	env := testEnv{
+		listener: listener,
+		server:   rpcServer,
+	}
 	clientBinder := NewClientBinder(func(context.Context, protocol.Server) protocol.Client { return client })
-	conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder)
-	if err != nil {
-		t.Fatal(err)
-	}
-	return testEnv{
-		listener:  listener,
-		rpcServer: rpcServer,
-		conn:      conn,
-	}
-}
-
-func TestClientLoggingV2(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	client := fakeClient{logs: make(chan string, 10)}
-	env := startServing(ctx, t, pingServer{}, client)
-	defer env.Shutdown(t)
-	if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
-		t.Errorf("DidOpen: %v", err)
-	}
-	select {
-	case got := <-client.logs:
-		want := "ping"
-		matched, err := regexp.MatchString(want, got)
+	if forwarded {
+		fwdListener, err := jsonrpc2_v2.NetPipeListener(ctx)
 		if err != nil {
 			t.Fatal(err)
 		}
-		if !matched {
-			t.Errorf("got log %q, want a log containing %q", got, want)
+		fwdBinder := NewForwardBinder(listener.Dialer())
+		fwdServer, err := jsonrpc2_v2.Serve(ctx, fwdListener, fwdBinder)
+		if err != nil {
+			t.Fatal(err)
 		}
-	case <-time.After(1 * time.Second):
-		t.Error("timeout waiting for client log")
+		conn, err := jsonrpc2_v2.Dial(ctx, fwdListener.Dialer(), clientBinder)
+		if err != nil {
+			t.Fatal(err)
+		}
+		env.fwdListener = fwdListener
+		env.fwdServer = fwdServer
+		env.conn = conn
+	} else {
+		conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder)
+		if err != nil {
+			t.Fatal(err)
+		}
+		env.conn = conn
+	}
+	return env
+}
+
+func TestClientLoggingV2(t *testing.T) {
+	ctx := context.Background()
+
+	for name, forwarded := range map[string]bool{
+		"forwarded":  true,
+		"standalone": false,
+	} {
+		t.Run(name, func(t *testing.T) {
+			client := fakeClient{logs: make(chan string, 10)}
+			env := startServing(ctx, t, pingServer{}, client, forwarded)
+			defer env.Shutdown(t)
+			if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
+				t.Errorf("DidOpen: %v", err)
+			}
+			select {
+			case got := <-client.logs:
+				want := "ping"
+				matched, err := regexp.MatchString(want, got)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if !matched {
+					t.Errorf("got log %q, want a log containing %q", got, want)
+				}
+			case <-time.After(1 * time.Second):
+				t.Error("timeout waiting for client log")
+			}
+		})
 	}
 }
 
 func TestRequestCancellationV2(t *testing.T) {
 	ctx := context.Background()
 
-	server := waitableServer{
-		started:   make(chan struct{}),
-		completed: make(chan error),
-	}
-	client := fakeClient{logs: make(chan string, 10)}
-	env := startServing(ctx, t, server, client)
-	defer env.Shutdown(t)
+	for name, forwarded := range map[string]bool{
+		"forwarded":  true,
+		"standalone": false,
+	} {
+		t.Run(name, func(t *testing.T) {
+			server := waitableServer{
+				started:   make(chan struct{}),
+				completed: make(chan error),
+			}
+			client := fakeClient{logs: make(chan string, 10)}
+			env := startServing(ctx, t, server, client, forwarded)
+			defer env.Shutdown(t)
 
-	sd := protocol.ServerDispatcherV2(env.conn)
-	ctx, cancel := context.WithCancel(ctx)
+			sd := protocol.ServerDispatcherV2(env.conn)
+			ctx, cancel := context.WithCancel(ctx)
 
-	result := make(chan error)
-	go func() {
-		_, err := sd.Hover(ctx, &protocol.HoverParams{})
-		result <- err
-	}()
-	// Wait for the Hover request to start.
-	<-server.started
-	cancel()
-	if err := <-result; err == nil {
-		t.Error("nil error for cancelled Hover(), want non-nil")
-	}
-	if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
-		t.Errorf("Hover(): unexpected server-side error %v", err)
+			result := make(chan error)
+			go func() {
+				_, err := sd.Hover(ctx, &protocol.HoverParams{})
+				result <- err
+			}()
+			// Wait for the Hover request to start.
+			<-server.started
+			cancel()
+			if err := <-result; err == nil {
+				t.Error("nil error for cancelled Hover(), want non-nil")
+			}
+			if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
+				t.Errorf("Hover(): unexpected server-side error %v", err)
+			}
+		})
 	}
 }