internal/imports,lsp: use callbacks for completion functions

We only need to return a relatively small number of completions to the
user. There's no point continuing once we have those, so switch the
completion functions to be callback-based, and cancel once we've got
what we want.

Change-Id: Ied199fb1f41346819c7237dfed8251fa3ac73ad7
Reviewed-on: https://go-review.googlesource.com/c/tools/+/212634
Run-TryBot: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/imports/fix.go b/internal/imports/fix.go
index 7ae035f..07536d2 100644
--- a/internal/imports/fix.go
+++ b/internal/imports/fix.go
@@ -82,7 +82,7 @@
 	IdentName string
 	// FixType is the type of fix this is (AddImport, DeleteImport, SetImportName).
 	FixType   ImportFixType
-	relevance int // see pkg
+	Relevance int // see pkg
 }
 
 // An ImportInfo represents a single import statement.
@@ -585,6 +585,10 @@
 	return fixes, nil
 }
 
+// Highest relevance, used for the standard library. Chosen arbitrarily to
+// match pre-existing gopls code.
+const MaxRelevance = 7
+
 // getCandidatePkgs returns the list of pkgs that are accessible from filename,
 // filtered to those that match pkgnameFilter.
 func getCandidatePkgs(ctx context.Context, wrappedCallback *scanCallback, filename string, env *ProcessEnv) error {
@@ -596,7 +600,7 @@
 			dir:             filepath.Join(env.GOROOT, "src", importPath),
 			importPathShort: importPath,
 			packageName:     path.Base(importPath),
-			relevance:       0,
+			relevance:       MaxRelevance,
 		}
 		if wrappedCallback.packageNameLoaded(p) {
 			wrappedCallback.exportsLoaded(p, exports)
@@ -630,17 +634,6 @@
 	return env.GetResolver().scan(ctx, scanFilter, exclude)
 }
 
-// Compare first by relevance, then by package name, with import path as a tiebreaker.
-func compareFix(fi, fj *ImportFix) bool {
-	if fi.relevance != fj.relevance {
-		return fi.relevance < fj.relevance
-	}
-	if fi.IdentName != fj.IdentName {
-		return fi.IdentName < fj.IdentName
-	}
-	return fi.StmtInfo.ImportPath < fj.StmtInfo.ImportPath
-}
-
 func candidateImportName(pkg *pkg) string {
 	if ImportPathToAssumedName(pkg.importPathShort) != pkg.packageName {
 		return pkg.packageName
@@ -649,39 +642,32 @@
 }
 
 // getAllCandidates gets all of the candidates to be imported, regardless of if they are needed.
-func getAllCandidates(ctx context.Context, prefix string, filename string, env *ProcessEnv) ([]ImportFix, error) {
-	var mu sync.Mutex
-	var results []ImportFix
-	filter := &scanCallback{
+func getAllCandidates(ctx context.Context, wrapped func(ImportFix), prefix string, filename string, env *ProcessEnv) error {
+	callback := &scanCallback{
 		dirFound: func(pkg *pkg) bool {
 			// TODO(heschi): apply dir match heuristics like pkgIsCandidate
 			return true
 		},
 		packageNameLoaded: func(pkg *pkg) bool {
 			if strings.HasPrefix(pkg.packageName, prefix) {
-				mu.Lock()
-				defer mu.Unlock()
-				results = append(results, ImportFix{
+				wrapped(ImportFix{
 					StmtInfo: ImportInfo{
 						ImportPath: pkg.importPathShort,
 						Name:       candidateImportName(pkg),
 					},
 					IdentName: pkg.packageName,
 					FixType:   AddImport,
-					relevance: pkg.relevance,
+					Relevance: pkg.relevance,
 				})
 			}
 			return false
 		},
 	}
-	err := getCandidatePkgs(ctx, filter, filename, env)
+	err := getCandidatePkgs(ctx, callback, filename, env)
 	if err != nil {
-		return nil, err
+		return err
 	}
-	sort.Slice(results, func(i, j int) bool {
-		return compareFix(&results[i], &results[j])
-	})
-	return results, nil
+	return nil
 }
 
 // A PackageExport is a package and its exports.
@@ -690,9 +676,7 @@
 	Exports []string
 }
 
-func getPackageExports(ctx context.Context, completePackage, filename string, env *ProcessEnv) ([]PackageExport, error) {
-	var mu sync.Mutex
-	var results []PackageExport
+func getPackageExports(ctx context.Context, wrapped func(PackageExport), completePackage, filename string, env *ProcessEnv) error {
 	callback := &scanCallback{
 		dirFound: func(pkg *pkg) bool {
 			// TODO(heschi): apply dir match heuristics like pkgIsCandidate
@@ -702,10 +686,8 @@
 			return pkg.packageName == completePackage
 		},
 		exportsLoaded: func(pkg *pkg, exports []string) {
-			mu.Lock()
-			defer mu.Unlock()
 			sort.Strings(exports)
-			results = append(results, PackageExport{
+			wrapped(PackageExport{
 				Fix: &ImportFix{
 					StmtInfo: ImportInfo{
 						ImportPath: pkg.importPathShort,
@@ -713,7 +695,7 @@
 					},
 					IdentName: pkg.packageName,
 					FixType:   AddImport,
-					relevance: pkg.relevance,
+					Relevance: pkg.relevance,
 				},
 				Exports: exports,
 			})
@@ -721,12 +703,9 @@
 	}
 	err := getCandidatePkgs(ctx, callback, filename, env)
 	if err != nil {
-		return nil, err
+		return err
 	}
-	sort.Slice(results, func(i, j int) bool {
-		return compareFix(results[i].Fix, results[j].Fix)
-	})
-	return results, nil
+	return nil
 }
 
 // ProcessEnv contains environment variables and settings that affect the use of
@@ -1175,10 +1154,10 @@
 		p := &pkg{
 			importPathShort: info.nonCanonicalImportPath,
 			dir:             dir,
-			relevance:       1,
+			relevance:       MaxRelevance - 1,
 		}
 		if info.rootType == gopathwalk.RootGOROOT {
-			p.relevance = 0
+			p.relevance = MaxRelevance
 		}
 
 		if callback.dirFound(p) {
diff --git a/internal/imports/fix_test.go b/internal/imports/fix_test.go
index b48a69a..9e1dff8 100644
--- a/internal/imports/fix_test.go
+++ b/internal/imports/fix_test.go
@@ -13,6 +13,7 @@
 	"path/filepath"
 	"reflect"
 	"runtime"
+	"sort"
 	"strings"
 	"sync"
 	"testing"
@@ -2492,15 +2493,15 @@
 // with correct priorities.
 func TestGetCandidates(t *testing.T) {
 	type res struct {
+		relevance  int
 		name, path string
 	}
 	want := []res{
-		{"bytes", "bytes"},
-		{"http", "net/http"},
-		{"rand", "crypto/rand"},
-		{"rand", "math/rand"},
-		{"bar", "bar.com/bar"},
-		{"foo", "foo.com/foo"},
+		{0, "bytes", "bytes"},
+		{0, "http", "net/http"},
+		{0, "rand", "crypto/rand"},
+		{0, "bar", "bar.com/bar"},
+		{0, "foo", "foo.com/foo"},
 	}
 
 	testConfig{
@@ -2515,18 +2516,31 @@
 			},
 		},
 	}.test(t, func(t *goimportTest) {
-		candidates, err := getAllCandidates(context.Background(), "", "x.go", t.env)
-		if err != nil {
-			t.Fatalf("GetAllCandidates() = %v", err)
-		}
+		var mu sync.Mutex
 		var got []res
-		for _, c := range candidates {
+		add := func(c ImportFix) {
+			mu.Lock()
+			defer mu.Unlock()
 			for _, w := range want {
 				if c.StmtInfo.ImportPath == w.path {
-					got = append(got, res{c.IdentName, c.StmtInfo.ImportPath})
+					got = append(got, res{c.Relevance, c.IdentName, c.StmtInfo.ImportPath})
 				}
 			}
 		}
+		if err := getAllCandidates(context.Background(), add, "", "x.go", t.env); err != nil {
+			t.Fatalf("GetAllCandidates() = %v", err)
+		}
+		// Sort, then clear out relevance so it doesn't mess up the DeepEqual.
+		sort.Slice(got, func(i, j int) bool {
+			ri, rj := got[i], got[j]
+			if ri.relevance != rj.relevance {
+				return ri.relevance > rj.relevance // Highest first.
+			}
+			return ri.name < rj.name
+		})
+		for i := range got {
+			got[i].relevance = 0
+		}
 		if !reflect.DeepEqual(want, got) {
 			t.Errorf("wanted stdlib results in order %v, got %v", want, got)
 		}
@@ -2535,12 +2549,12 @@
 
 func TestGetPackageCompletions(t *testing.T) {
 	type res struct {
+		relevance          int
 		name, path, symbol string
 	}
 	want := []res{
-		{"rand", "crypto/rand", "Prime"},
-		{"rand", "math/rand", "Seed"},
-		{"rand", "bar.com/rand", "Bar"},
+		{0, "rand", "math/rand", "Seed"},
+		{0, "rand", "bar.com/rand", "Bar"},
 	}
 
 	testConfig{
@@ -2551,20 +2565,33 @@
 			},
 		},
 	}.test(t, func(t *goimportTest) {
-		candidates, err := getPackageExports(context.Background(), "rand", "x.go", t.env)
-		if err != nil {
-			t.Fatalf("getPackageCompletions() = %v", err)
-		}
+		var mu sync.Mutex
 		var got []res
-		for _, c := range candidates {
+		add := func(c PackageExport) {
+			mu.Lock()
+			defer mu.Unlock()
 			for _, csym := range c.Exports {
 				for _, w := range want {
 					if c.Fix.StmtInfo.ImportPath == w.path && csym == w.symbol {
-						got = append(got, res{c.Fix.IdentName, c.Fix.StmtInfo.ImportPath, csym})
+						got = append(got, res{c.Fix.Relevance, c.Fix.IdentName, c.Fix.StmtInfo.ImportPath, csym})
 					}
 				}
 			}
 		}
+		if err := getPackageExports(context.Background(), add, "rand", "x.go", t.env); err != nil {
+			t.Fatalf("getPackageCompletions() = %v", err)
+		}
+		// Sort, then clear out relevance so it doesn't mess up the DeepEqual.
+		sort.Slice(got, func(i, j int) bool {
+			ri, rj := got[i], got[j]
+			if ri.relevance != rj.relevance {
+				return ri.relevance > rj.relevance // Highest first.
+			}
+			return ri.name < rj.name
+		})
+		for i := range got {
+			got[i].relevance = 0
+		}
 		if !reflect.DeepEqual(want, got) {
 			t.Errorf("wanted stdlib results in order %v, got %v", want, got)
 		}
diff --git a/internal/imports/imports.go b/internal/imports/imports.go
index c857043..3855c8a 100644
--- a/internal/imports/imports.go
+++ b/internal/imports/imports.go
@@ -118,21 +118,21 @@
 
 // GetAllCandidates gets all of the packages starting with prefix that can be
 // imported by filename, sorted by import path.
-func GetAllCandidates(ctx context.Context, prefix string, filename string, opt *Options) (pkgs []ImportFix, err error) {
-	_, opt, err = initialize(filename, nil, opt)
+func GetAllCandidates(ctx context.Context, callback func(ImportFix), prefix string, filename string, opt *Options) error {
+	_, opt, err := initialize(filename, nil, opt)
 	if err != nil {
-		return nil, err
+		return err
 	}
-	return getAllCandidates(ctx, prefix, filename, opt.Env)
+	return getAllCandidates(ctx, callback, prefix, filename, opt.Env)
 }
 
 // GetPackageExports returns all known packages with name pkg and their exports.
-func GetPackageExports(ctx context.Context, pkg, filename string, opt *Options) (exports []PackageExport, err error) {
-	_, opt, err = initialize(filename, nil, opt)
+func GetPackageExports(ctx context.Context, callback func(PackageExport), pkg, filename string, opt *Options) error {
+	_, opt, err := initialize(filename, nil, opt)
 	if err != nil {
-		return nil, err
+		return err
 	}
-	return getPackageExports(ctx, pkg, filename, opt.Env)
+	return getPackageExports(ctx, callback, pkg, filename, opt.Env)
 }
 
 // initialize sets the values for opt and src.
diff --git a/internal/imports/mod.go b/internal/imports/mod.go
index 2632937..fb665a3 100644
--- a/internal/imports/mod.go
+++ b/internal/imports/mod.go
@@ -425,18 +425,18 @@
 			importPathShort: info.nonCanonicalImportPath,
 			dir:             info.dir,
 			packageName:     path.Base(info.nonCanonicalImportPath),
-			relevance:       0,
+			relevance:       MaxRelevance,
 		}, nil
 	}
 
 	importPath := info.nonCanonicalImportPath
-	relevance := 3
+	relevance := MaxRelevance - 3
 	// Check if the directory is underneath a module that's in scope.
 	if mod := r.findModuleByDir(info.dir); mod != nil {
 		if mod.Indirect {
-			relevance = 2
+			relevance = MaxRelevance - 2
 		} else {
-			relevance = 1
+			relevance = MaxRelevance - 1
 		}
 		// It is. If dir is the target of a replace directive,
 		// our guessed import path is wrong. Use the real one.
diff --git a/internal/imports/mod_test.go b/internal/imports/mod_test.go
index 61a6beb..e98f222 100644
--- a/internal/imports/mod_test.go
+++ b/internal/imports/mod_test.go
@@ -13,6 +13,7 @@
 	"path/filepath"
 	"reflect"
 	"regexp"
+	"sort"
 	"strings"
 	"sync"
 	"testing"
@@ -871,32 +872,42 @@
 	}
 
 	type res struct {
+		relevance  int
 		name, path string
 	}
 	want := []res{
 		// Stdlib
-		{"bytes", "bytes"},
-		{"http", "net/http"},
+		{7, "bytes", "bytes"},
+		{7, "http", "net/http"},
 		// Direct module deps
-		{"quote", "rsc.io/quote"},
+		{6, "quote", "rsc.io/quote"},
+		{6, "rpackage", "example.com/rpackage"},
 		// Indirect deps
-		{"rpackage", "example.com/rpackage"},
-		{"language", "golang.org/x/text/language"},
+		{5, "language", "golang.org/x/text/language"},
 		// Out of scope modules
-		{"quote", "rsc.io/quote/v2"},
+		{4, "quote", "rsc.io/quote/v2"},
 	}
-	candidates, err := getAllCandidates(context.Background(), "", "foo.go", mt.env)
-	if err != nil {
-		t.Fatalf("getAllCandidates() = %v", err)
-	}
+	var mu sync.Mutex
 	var got []res
-	for _, c := range candidates {
+	add := func(c ImportFix) {
+		mu.Lock()
+		defer mu.Unlock()
 		for _, w := range want {
 			if c.StmtInfo.ImportPath == w.path {
-				got = append(got, res{c.IdentName, c.StmtInfo.ImportPath})
+				got = append(got, res{c.Relevance, c.IdentName, c.StmtInfo.ImportPath})
 			}
 		}
 	}
+	if err := getAllCandidates(context.Background(), add, "", "foo.go", mt.env); err != nil {
+		t.Fatalf("getAllCandidates() = %v", err)
+	}
+	sort.Slice(got, func(i, j int) bool {
+		ri, rj := got[i], got[j]
+		if ri.relevance != rj.relevance {
+			return ri.relevance > rj.relevance // Highest first.
+		}
+		return ri.name < rj.name
+	})
 	if !reflect.DeepEqual(want, got) {
 		t.Errorf("wanted candidates in order %v, got %v", want, got)
 	}
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 767181d..ef2318e 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -14,6 +14,7 @@
 	"math"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"golang.org/x/tools/go/ast/astutil"
@@ -245,6 +246,13 @@
 	return p.content[p.cursor-p.spanRange.Start:]
 }
 
+func (c *completer) deepCompletionContext() (context.Context, context.CancelFunc) {
+	if c.opts.Budget == 0 {
+		return context.WithCancel(c.ctx)
+	}
+	return context.WithDeadline(c.ctx, c.startTime.Add(c.opts.Budget))
+}
+
 func (c *completer) setSurrounding(ident *ast.Ident) {
 	if c.surrounding != nil {
 		return
@@ -622,15 +630,11 @@
 
 	// Try unimported packages.
 	if id, ok := sel.X.(*ast.Ident); ok && c.opts.Unimported && len(c.items) < unimportedTarget {
-		pkgExports, err := PackageExports(c.ctx, c.snapshot.View(), id.Name, c.filename)
-		if err != nil {
-			return err
-		}
+		ctx, cancel := c.deepCompletionContext()
+		defer cancel()
+
 		known := c.snapshot.KnownImportPaths()
-		for _, pkgExport := range pkgExports {
-			if len(c.items) >= unimportedTarget {
-				break
-			}
+		add := func(pkgExport imports.PackageExport) {
 			// If we've seen this import path, use the fully-typed version.
 			if knownPkg, ok := known[pkgExport.Fix.StmtInfo.ImportPath]; ok {
 				c.packageMembers(knownPkg.GetTypes(), &importInfo{
@@ -638,22 +642,30 @@
 					name:       pkgExport.Fix.StmtInfo.Name,
 					pkg:        knownPkg,
 				})
-				continue
+			} else {
+				// Otherwise, continue with untyped proposals.
+				pkg := types.NewPackage(pkgExport.Fix.StmtInfo.ImportPath, pkgExport.Fix.IdentName)
+				for _, export := range pkgExport.Exports {
+					score := 0.01 * float64(pkgExport.Fix.Relevance)
+					c.found(candidate{
+						obj:   types.NewVar(0, pkg, export, nil),
+						score: score,
+						imp: &importInfo{
+							importPath: pkgExport.Fix.StmtInfo.ImportPath,
+							name:       pkgExport.Fix.StmtInfo.Name,
+						},
+					})
+				}
 			}
-
-			// Otherwise, continue with untyped proposals.
-			pkg := types.NewPackage(pkgExport.Fix.StmtInfo.ImportPath, pkgExport.Fix.IdentName)
-			for _, export := range pkgExport.Exports {
-				c.found(candidate{
-					obj:   types.NewVar(0, pkg, export, nil),
-					score: 0.07,
-					imp: &importInfo{
-						importPath: pkgExport.Fix.StmtInfo.ImportPath,
-						name:       pkgExport.Fix.StmtInfo.Name,
-					},
-				})
+			if len(c.items) >= unimportedTarget {
+				cancel()
 			}
 		}
+		if err := c.snapshot.View().RunProcessEnvFunc(ctx, func(opts *imports.Options) error {
+			return imports.GetPackageExports(ctx, add, id.Name, c.filename, opts)
+		}); err != nil && err != context.Canceled {
+			return err
+		}
 	}
 	return nil
 }
@@ -830,39 +842,52 @@
 	}
 
 	if c.opts.Unimported && len(c.items) < unimportedTarget {
-		ctx, cancel := context.WithDeadline(c.ctx, c.startTime.Add(c.opts.Budget))
+		ctx, cancel := c.deepCompletionContext()
 		defer cancel()
 		// Suggest packages that have not been imported yet.
 		prefix := ""
 		if c.surrounding != nil {
 			prefix = c.surrounding.Prefix()
 		}
-		pkgs, err := CandidateImports(ctx, prefix, c.snapshot.View(), c.filename)
-		if err != nil {
-			return err
-		}
-		score := stdScore
-		// Rank unimported packages significantly lower than other results.
-		score *= 0.07
+		var mu sync.Mutex
+		add := func(pkg imports.ImportFix) {
+			mu.Lock()
+			defer mu.Unlock()
+			if _, ok := seen[pkg.IdentName]; ok {
+				return
+			}
+			// Rank unimported packages significantly lower than other results.
+			score := 0.01 * float64(pkg.Relevance)
 
-		for _, pkg := range pkgs {
+			// Do not add the unimported packages to seen, since we can have
+			// multiple packages of the same name as completion suggestions, since
+			// only one will be chosen.
+			obj := types.NewPkgName(0, nil, pkg.IdentName, types.NewPackage(pkg.StmtInfo.ImportPath, pkg.IdentName))
+			c.found(candidate{
+				obj:   obj,
+				score: score,
+				imp: &importInfo{
+					importPath: pkg.StmtInfo.ImportPath,
+					name:       pkg.StmtInfo.Name,
+				},
+			})
+
 			if len(c.items) >= unimportedTarget {
-				break
+				cancel()
 			}
-			if _, ok := seen[pkg.IdentName]; !ok {
-				// Do not add the unimported packages to seen, since we can have
-				// multiple packages of the same name as completion suggestions, since
-				// only one will be chosen.
-				obj := types.NewPkgName(0, nil, pkg.IdentName, types.NewPackage(pkg.StmtInfo.ImportPath, pkg.IdentName))
-				c.found(candidate{
-					obj:   obj,
-					score: score,
-					imp: &importInfo{
-						importPath: pkg.StmtInfo.ImportPath,
-						name:       pkg.StmtInfo.Name,
-					},
-				})
-			}
+			c.found(candidate{
+				obj:   obj,
+				score: score,
+				imp: &importInfo{
+					importPath: pkg.StmtInfo.ImportPath,
+					name:       pkg.StmtInfo.Name,
+				},
+			})
+		}
+		if err := c.snapshot.View().RunProcessEnvFunc(ctx, func(opts *imports.Options) error {
+			return imports.GetAllCandidates(ctx, add, prefix, c.filename, opts)
+		}); err != nil && err != context.Canceled {
+			return err
 		}
 	}
 
diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go
index e11e2cd..406c795 100644
--- a/internal/lsp/source/format.go
+++ b/internal/lsp/source/format.go
@@ -303,37 +303,6 @@
 	return src[0:fset.Position(end).Offset], true
 }
 
-// CandidateImports returns every import that could be added to filename.
-func CandidateImports(ctx context.Context, prefix string, view View, filename string) ([]imports.ImportFix, error) {
-	ctx, done := trace.StartSpan(ctx, "source.CandidateImports")
-	defer done()
-
-	var imps []imports.ImportFix
-	importFn := func(opts *imports.Options) error {
-		var err error
-		imps, err = imports.GetAllCandidates(ctx, prefix, filename, opts)
-		return err
-	}
-	err := view.RunProcessEnvFunc(ctx, importFn)
-	return imps, err
-}
-
-// PackageExports returns all the packages named pkg that could be imported by
-// filename, and their exports.
-func PackageExports(ctx context.Context, view View, pkg, filename string) ([]imports.PackageExport, error) {
-	ctx, done := trace.StartSpan(ctx, "source.PackageExports")
-	defer done()
-
-	var pkgs []imports.PackageExport
-	importFn := func(opts *imports.Options) error {
-		var err error
-		pkgs, err = imports.GetPackageExports(ctx, pkg, filename, opts)
-		return err
-	}
-	err := view.RunProcessEnvFunc(ctx, importFn)
-	return pkgs, err
-}
-
 // hasParseErrors returns true if the given file has parse errors.
 func hasParseErrors(pkg Package, uri span.URI) bool {
 	for _, e := range pkg.GetErrors() {
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 734b201..3b9d129 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -20,6 +20,7 @@
 	"strings"
 	"sync"
 	"testing"
+	"time"
 
 	"golang.org/x/tools/go/expect"
 	"golang.org/x/tools/go/packages"
@@ -190,6 +191,7 @@
 	}
 	o.HoverKind = source.SynopsisDocumentation
 	o.InsertTextFormat = protocol.SnippetTextFormat
+	o.Completion.Budget = time.Minute
 	return o
 }