internal/lsp/lsprpc: add an AutoDialer abstraction

Refactor the lsprpc package to move the logic for 'automatic' server
discovery into an AutoDialer abstraction, which both implements the v2
jsonrpc2 Dialer interface, and provides a dialNet method that can be
used for the existing v1 APIs.

Along the way, simplify the evaluation of remote arguments to eliminate
the overly abstract RemoteOption.

Change-Id: Ic3def17ccc237007a7eb2cc41a12cf058fca9be3
Reviewed-on: https://go-review.googlesource.com/c/tools/+/332490
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go
index 6d0787e..4164b58 100644
--- a/internal/lsp/cmd/serve.go
+++ b/internal/lsp/cmd/serve.go
@@ -56,6 +56,22 @@
 	f.PrintDefaults()
 }
 
+func (s *Serve) remoteArgs(network, address string) []string {
+	args := []string{"serve",
+		"-listen", fmt.Sprintf(`%s;%s`, network, address),
+	}
+	if s.RemoteDebug != "" {
+		args = append(args, "-debug", s.RemoteDebug)
+	}
+	if s.RemoteListenTimeout != 0 {
+		args = append(args, "-listen.timeout", s.RemoteListenTimeout.String())
+	}
+	if s.RemoteLogfile != "" {
+		args = append(args, "-logfile", s.RemoteLogfile)
+	}
+	return args
+}
+
 // Run configures a server based on the flags, and then runs it.
 // It blocks until the server shuts down.
 func (s *Serve) Run(ctx context.Context, args ...string) error {
@@ -77,12 +93,11 @@
 	}
 	var ss jsonrpc2.StreamServer
 	if s.app.Remote != "" {
-		network, addr := lsprpc.ParseAddr(s.app.Remote)
-		ss = lsprpc.NewForwarder(network, addr,
-			lsprpc.RemoteDebugAddress(s.RemoteDebug),
-			lsprpc.RemoteListenTimeout(s.RemoteListenTimeout),
-			lsprpc.RemoteLogfile(s.RemoteLogfile),
-		)
+		var err error
+		ss, err = lsprpc.NewForwarder(s.app.Remote, s.remoteArgs)
+		if err != nil {
+			return errors.Errorf("creating forwarder: %w", err)
+		}
 	} else {
 		ss = lsprpc.NewStreamServer(cache.New(s.app.options), isDaemon)
 	}
diff --git a/internal/lsp/lsprpc/autostart_default.go b/internal/lsp/lsprpc/autostart_default.go
index dc04f66..b23a1e5 100644
--- a/internal/lsp/lsprpc/autostart_default.go
+++ b/internal/lsp/lsprpc/autostart_default.go
@@ -11,13 +11,13 @@
 )
 
 var (
-	startRemote           = startRemoteDefault
+	daemonize             = func(*exec.Cmd) {}
 	autoNetworkAddress    = autoNetworkAddressDefault
 	verifyRemoteOwnership = verifyRemoteOwnershipDefault
 )
 
-func startRemoteDefault(goplsPath string, args ...string) error {
-	cmd := exec.Command(goplsPath, args...)
+func runRemote(cmd *exec.Cmd) error {
+	daemonize(cmd)
 	if err := cmd.Start(); err != nil {
 		return errors.Errorf("starting remote gopls: %w", err)
 	}
diff --git a/internal/lsp/lsprpc/autostart_posix.go b/internal/lsp/lsprpc/autostart_posix.go
index 45089b8..d5644e2 100644
--- a/internal/lsp/lsprpc/autostart_posix.go
+++ b/internal/lsp/lsprpc/autostart_posix.go
@@ -11,7 +11,6 @@
 	"crypto/sha256"
 	"errors"
 	"fmt"
-	exec "golang.org/x/sys/execabs"
 	"log"
 	"os"
 	"os/user"
@@ -19,24 +18,21 @@
 	"strconv"
 	"syscall"
 
+	exec "golang.org/x/sys/execabs"
+
 	"golang.org/x/xerrors"
 )
 
 func init() {
-	startRemote = startRemotePosix
+	daemonize = daemonizePosix
 	autoNetworkAddress = autoNetworkAddressPosix
 	verifyRemoteOwnership = verifyRemoteOwnershipPosix
 }
 
-func startRemotePosix(goplsPath string, args ...string) error {
-	cmd := exec.Command(goplsPath, args...)
+func daemonizePosix(cmd *exec.Cmd) {
 	cmd.SysProcAttr = &syscall.SysProcAttr{
 		Setsid: true,
 	}
-	if err := cmd.Start(); err != nil {
-		return xerrors.Errorf("starting remote gopls: %w", err)
-	}
-	return nil
 }
 
 // autoNetworkAddress resolves an id on the 'auto' pseduo-network to a
diff --git a/internal/lsp/lsprpc/dialer.go b/internal/lsp/lsprpc/dialer.go
new file mode 100644
index 0000000..713307c
--- /dev/null
+++ b/internal/lsp/lsprpc/dialer.go
@@ -0,0 +1,115 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lsprpc
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"time"
+
+	exec "golang.org/x/sys/execabs"
+	"golang.org/x/tools/internal/event"
+	errors "golang.org/x/xerrors"
+)
+
+// AutoNetwork is the pseudo network type used to signal that gopls should use
+// automatic discovery to resolve a remote address.
+const AutoNetwork = "auto"
+
+// An AutoDialer is a jsonrpc2 dialer that understands the 'auto' network.
+type AutoDialer struct {
+	network, addr string // the 'real' network and address
+	isAuto        bool   // whether the server is on the 'auto' network
+
+	executable string
+	argFunc    func(network, addr string) []string
+}
+
+func NewAutoDialer(rawAddr string, argFunc func(network, addr string) []string) (*AutoDialer, error) {
+	d := AutoDialer{
+		argFunc: argFunc,
+	}
+	d.network, d.addr = ParseAddr(rawAddr)
+	if d.network == AutoNetwork {
+		d.isAuto = true
+		bin, err := os.Executable()
+		if err != nil {
+			return nil, errors.Errorf("getting executable: %w", err)
+		}
+		d.executable = bin
+		d.network, d.addr = autoNetworkAddress(bin, d.addr)
+	}
+	return &d, nil
+}
+
+// Dial implements the jsonrpc2.Dialer interface.
+func (d *AutoDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) {
+	conn, err := d.dialNet(ctx)
+	return conn, err
+}
+
+// TODO(rfindley): remove this once we no longer need to integrate with v1 of
+// the jsonrpc2 package.
+func (d *AutoDialer) dialNet(ctx context.Context) (net.Conn, error) {
+	// Attempt to verify that we own the remote. This is imperfect, but if we can
+	// determine that the remote is owned by a different user, we should fail.
+	ok, err := verifyRemoteOwnership(d.network, d.addr)
+	if err != nil {
+		// If the ownership check itself failed, we fail open but log an error to
+		// the user.
+		event.Error(ctx, "unable to check daemon socket owner, failing open", err)
+	} else if !ok {
+		// We successfully checked that the socket is not owned by us, we fail
+		// closed.
+		return nil, fmt.Errorf("socket %q is owned by a different user", d.addr)
+	}
+	const dialTimeout = 1 * time.Second
+	// Try dialing our remote once, in case it is already running.
+	netConn, err := net.DialTimeout(d.network, d.addr, dialTimeout)
+	if err == nil {
+		return netConn, nil
+	}
+	if d.isAuto && d.argFunc != nil {
+		if d.network == "unix" {
+			// Sometimes the socketfile isn't properly cleaned up when the server
+			// shuts down. Since we have already tried and failed to dial this
+			// address, it should *usually* be safe to remove the socket before
+			// binding to the address.
+			// TODO(rfindley): there is probably a race here if multiple server
+			// instances are simultaneously starting up.
+			if _, err := os.Stat(d.addr); err == nil {
+				if err := os.Remove(d.addr); err != nil {
+					return nil, errors.Errorf("removing remote socket file: %w", err)
+				}
+			}
+		}
+		args := d.argFunc(d.network, d.addr)
+		cmd := exec.Command(d.executable, args...)
+		if err := runRemote(cmd); err != nil {
+			return nil, err
+		}
+	}
+
+	const retries = 5
+	// It can take some time for the newly started server to bind to our address,
+	// so we retry for a bit.
+	for retry := 0; retry < retries; retry++ {
+		startDial := time.Now()
+		netConn, err = net.DialTimeout(d.network, d.addr, dialTimeout)
+		if err == nil {
+			return netConn, nil
+		}
+		event.Log(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err))
+		// In case our failure was a fast-failure, ensure we wait at least
+		// dialTimeout before trying again.
+		if retry != retries-1 {
+			time.Sleep(dialTimeout - time.Since(startDial))
+		}
+	}
+	return nil, errors.Errorf("dialing remote: %w", err)
+}
diff --git a/internal/lsp/lsprpc/lsprpc.go b/internal/lsp/lsprpc/lsprpc.go
index 18e0299..9177078 100644
--- a/internal/lsp/lsprpc/lsprpc.go
+++ b/internal/lsp/lsprpc/lsprpc.go
@@ -21,7 +21,6 @@
 
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/jsonrpc2"
-	jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
 	"golang.org/x/tools/internal/lsp"
 	"golang.org/x/tools/internal/lsp/cache"
 	"golang.org/x/tools/internal/lsp/command"
@@ -31,10 +30,6 @@
 	errors "golang.org/x/xerrors"
 )
 
-// AutoNetwork is the pseudo network type used to signal that gopls should use
-// automatic discovery to resolve a remote address.
-const AutoNetwork = "auto"
-
 // Unique identifiers for client/server.
 var serverIndex int64
 
@@ -113,13 +108,7 @@
 // be instrumented with telemetry, and want to be able to in some cases hijack
 // the jsonrpc2 connection with the daemon.
 type Forwarder struct {
-	network, addr string
-
-	// goplsPath is the path to the current executing gopls binary.
-	goplsPath string
-
-	// configuration for the auto-started gopls remote.
-	remoteConfig remoteConfig
+	dialer *AutoDialer
 
 	mu sync.Mutex
 	// Hold on to the server connection so that we can redo the handshake if any
@@ -128,68 +117,19 @@
 	serverID   string
 }
 
-type remoteConfig struct {
-	debug         string
-	listenTimeout time.Duration
-	logfile       string
-}
-
-// A RemoteOption configures the behavior of the auto-started remote.
-type RemoteOption interface {
-	set(*remoteConfig)
-}
-
-// RemoteDebugAddress configures the address used by the auto-started Gopls daemon
-// for serving debug information.
-type RemoteDebugAddress string
-
-func (d RemoteDebugAddress) set(cfg *remoteConfig) {
-	cfg.debug = string(d)
-}
-
-// RemoteListenTimeout configures the amount of time the auto-started gopls
-// daemon will wait with no client connections before shutting down.
-type RemoteListenTimeout time.Duration
-
-func (d RemoteListenTimeout) set(cfg *remoteConfig) {
-	cfg.listenTimeout = time.Duration(d)
-}
-
-// RemoteLogfile configures the logfile location for the auto-started gopls
-// daemon.
-type RemoteLogfile string
-
-func (l RemoteLogfile) set(cfg *remoteConfig) {
-	cfg.logfile = string(l)
-}
-
-func defaultRemoteConfig() remoteConfig {
-	return remoteConfig{
-		listenTimeout: 1 * time.Minute,
-	}
-}
-
 // NewForwarder creates a new Forwarder, ready to forward connections to the
-// remote server specified by network and addr.
-func NewForwarder(network, addr string, opts ...RemoteOption) *Forwarder {
-	gp, err := os.Executable()
+// remote server specified by rawAddr. If provided and rawAddr indicates an
+// 'automatic' address (starting with 'auto;'), argFunc may be used to start a
+// remote server for the auto-discovered address.
+func NewForwarder(rawAddr string, argFunc func(network, address string) []string) (*Forwarder, error) {
+	dialer, err := NewAutoDialer(rawAddr, argFunc)
 	if err != nil {
-		log.Printf("error getting gopls path for forwarder: %v", err)
-		gp = ""
+		return nil, err
 	}
-
-	rcfg := defaultRemoteConfig()
-	for _, opt := range opts {
-		opt.set(&rcfg)
-	}
-
 	fwd := &Forwarder{
-		network:      network,
-		addr:         addr,
-		goplsPath:    gp,
-		remoteConfig: rcfg,
+		dialer: dialer,
 	}
-	return fwd
+	return fwd, nil
 }
 
 // QueryServerState queries the server state of the current server.
@@ -247,7 +187,7 @@
 func (f *Forwarder) ServeStream(ctx context.Context, clientConn jsonrpc2.Conn) error {
 	client := protocol.ClientDispatcher(clientConn)
 
-	netConn, err := f.connectToRemote(ctx)
+	netConn, err := f.dialer.dialNet(ctx)
 	if err != nil {
 		return errors.Errorf("forwarder: connecting to remote: %w", err)
 	}
@@ -293,19 +233,19 @@
 	return err
 }
 
-func (f *Forwarder) Binder() *ForwardBinder {
-	network, address := realNetworkAddress(f.network, f.addr, f.goplsPath)
-	dialer := jsonrpc2_v2.NetDialer(network, address, net.Dialer{
-		Timeout: 5 * time.Second,
-	})
-	return NewForwardBinder(dialer)
-}
-
+// TODO(rfindley): remove this handshaking in favor of middleware.
 func (f *Forwarder) handshake(ctx context.Context) {
+	// This call to os.Executable is redundant, and will be eliminated by the
+	// transition to the V2 API.
+	goplsPath, err := os.Executable()
+	if err != nil {
+		event.Error(ctx, "getting executable for handshake", err)
+		goplsPath = ""
+	}
 	var (
 		hreq = handshakeRequest{
 			ServerID:  f.serverID,
-			GoplsPath: f.goplsPath,
+			GoplsPath: goplsPath,
 		}
 		hresp handshakeResponse
 	)
@@ -318,8 +258,8 @@
 		// here.  Handshakes have become functional in nature.
 		event.Error(ctx, "forwarder: gopls handshake failed", err)
 	}
-	if hresp.GoplsPath != f.goplsPath {
-		event.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", f.goplsPath, hresp.GoplsPath))
+	if hresp.GoplsPath != goplsPath {
+		event.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", goplsPath, hresp.GoplsPath))
 	}
 	event.Log(ctx, "New server",
 		tag.NewServer.Of(f.serverID),
@@ -330,108 +270,12 @@
 	)
 }
 
-func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) {
-	return connectToRemote(ctx, f.network, f.addr, f.goplsPath, f.remoteConfig)
-}
-
-func ConnectToRemote(ctx context.Context, addr string, opts ...RemoteOption) (net.Conn, error) {
-	rcfg := defaultRemoteConfig()
-	for _, opt := range opts {
-		opt.set(&rcfg)
-	}
-	// This is not strictly necessary, as it won't be used if not connecting to
-	// the 'auto' remote.
-	goplsPath, err := os.Executable()
+func ConnectToRemote(ctx context.Context, addr string) (net.Conn, error) {
+	dialer, err := NewAutoDialer(addr, nil)
 	if err != nil {
-		return nil, fmt.Errorf("unable to resolve gopls path: %v", err)
+		return nil, err
 	}
-	network, address := ParseAddr(addr)
-	return connectToRemote(ctx, network, address, goplsPath, rcfg)
-}
-
-func realNetworkAddress(inNetwork, inAddr, goplsPath string) (network, address string) {
-	if inNetwork != AutoNetwork {
-		return inNetwork, inAddr
-	}
-	// The "auto" network is a fake network used for service discovery. It
-	// resolves a known address based on gopls binary path.
-	return autoNetworkAddress(goplsPath, inAddr)
-}
-
-func connectToRemote(ctx context.Context, inNetwork, inAddr, goplsPath string, rcfg remoteConfig) (net.Conn, error) {
-	var (
-		netConn          net.Conn
-		err              error
-		network, address = realNetworkAddress(inNetwork, inAddr, goplsPath)
-	)
-	// Attempt to verify that we own the remote. This is imperfect, but if we can
-	// determine that the remote is owned by a different user, we should fail.
-	ok, err := verifyRemoteOwnership(network, address)
-	if err != nil {
-		// If the ownership check itself failed, we fail open but log an error to
-		// the user.
-		event.Error(ctx, "unable to check daemon socket owner, failing open", err)
-	} else if !ok {
-		// We successfully checked that the socket is not owned by us, we fail
-		// closed.
-		return nil, fmt.Errorf("socket %q is owned by a different user", address)
-	}
-	const dialTimeout = 1 * time.Second
-	// Try dialing our remote once, in case it is already running.
-	netConn, err = net.DialTimeout(network, address, dialTimeout)
-	if err == nil {
-		return netConn, nil
-	}
-	// If our remote is on the 'auto' network, start it if it doesn't exist.
-	if inNetwork == AutoNetwork {
-		if goplsPath == "" {
-			return nil, fmt.Errorf("cannot auto-start remote: gopls path is unknown")
-		}
-		if network == "unix" {
-			// Sometimes the socketfile isn't properly cleaned up when gopls shuts
-			// down. Since we have already tried and failed to dial this address, it
-			// should *usually* be safe to remove the socket before binding to the
-			// address.
-			// TODO(rfindley): there is probably a race here if multiple gopls
-			// instances are simultaneously starting up.
-			if _, err := os.Stat(address); err == nil {
-				if err := os.Remove(address); err != nil {
-					return nil, errors.Errorf("removing remote socket file: %w", err)
-				}
-			}
-		}
-		args := []string{"serve",
-			"-listen", fmt.Sprintf(`%s;%s`, network, address),
-			"-listen.timeout", rcfg.listenTimeout.String(),
-		}
-		if rcfg.logfile != "" {
-			args = append(args, "-logfile", rcfg.logfile)
-		}
-		if rcfg.debug != "" {
-			args = append(args, "-debug", rcfg.debug)
-		}
-		if err := startRemote(goplsPath, args...); err != nil {
-			return nil, errors.Errorf("startRemote(%q, %v): %w", goplsPath, args, err)
-		}
-	}
-
-	const retries = 5
-	// It can take some time for the newly started server to bind to our address,
-	// so we retry for a bit.
-	for retry := 0; retry < retries; retry++ {
-		startDial := time.Now()
-		netConn, err = net.DialTimeout(network, address, dialTimeout)
-		if err == nil {
-			return netConn, nil
-		}
-		event.Log(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err))
-		// In case our failure was a fast-failure, ensure we wait at least
-		// f.dialTimeout before trying again.
-		if retry != retries-1 {
-			time.Sleep(dialTimeout - time.Since(startDial))
-		}
-	}
-	return nil, errors.Errorf("dialing remote: %w", err)
+	return dialer.dialNet(ctx)
 }
 
 // handler intercepts messages to the daemon to enrich them with local
diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go
index b2902fa..24decbe 100644
--- a/internal/lsp/lsprpc/lsprpc_test.go
+++ b/internal/lsp/lsprpc/lsprpc_test.go
@@ -126,7 +126,10 @@
 	tsDirect := servertest.NewTCPServer(serveCtx, ss, nil)
 
 	forwarderCtx := debug.WithInstance(ctx, "", "")
-	forwarder := NewForwarder("tcp", tsDirect.Addr)
+	forwarder, err := NewForwarder("tcp;"+tsDirect.Addr, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
 	tsForwarded := servertest.NewPipeServer(forwarderCtx, forwarder, nil)
 	return tsDirect, tsForwarded, func() {
 		checkClose(t, tsDirect.Close)
@@ -218,7 +221,10 @@
 	ss := NewStreamServer(cache, false)
 	tsBackend := servertest.NewTCPServer(serverCtx, ss, nil)
 
-	forwarder := NewForwarder("tcp", tsBackend.Addr)
+	forwarder, err := NewForwarder("tcp;"+tsBackend.Addr, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
 	tsForwarder := servertest.NewPipeServer(clientCtx, forwarder, nil)
 
 	conn1 := tsForwarder.Connect(clientCtx)
diff --git a/internal/lsp/regtest/runner.go b/internal/lsp/regtest/runner.go
index fb17b49..6b3501c 100644
--- a/internal/lsp/regtest/runner.go
+++ b/internal/lsp/regtest/runner.go
@@ -415,7 +415,7 @@
 
 func (r *Runner) forwardedServer(ctx context.Context, t *testing.T, optsHook func(*source.Options)) jsonrpc2.StreamServer {
 	ts := r.getTestServer(optsHook)
-	return lsprpc.NewForwarder("tcp", ts.Addr)
+	return newForwarder("tcp", ts.Addr)
 }
 
 // getTestServer gets the shared test server instance to connect to, or creates
@@ -436,7 +436,16 @@
 	// TODO(rfindley): can we use the autostart behavior here, instead of
 	// pre-starting the remote?
 	socket := r.getRemoteSocket(t)
-	return lsprpc.NewForwarder("unix", socket)
+	return newForwarder("unix", socket)
+}
+
+func newForwarder(network, address string) *lsprpc.Forwarder {
+	server, err := lsprpc.NewForwarder(network+";"+address, nil)
+	if err != nil {
+		// This should never happen, as we are passing an explicit address.
+		panic(fmt.Sprintf("internal error: unable to create forwarder: %v", err))
+	}
+	return server
 }
 
 // runTestAsGoplsEnvvar triggers TestMain to run gopls instead of running