internal/lsp/source: speed up completion candidate formatting

Completion could be slow due to calls to astutil.PathEnclosingInterval
for every candidate during formatting. There were two reasons we
called PEI:

1. To properly render type alias names, we must refer to the AST
   because the alias name is not available in the typed world.
   Previously we would call PEI to find the *type.Var's
   corresponding *ast.Field, but now we have a PosToField cache that
   lets us jump straight from the types.Object's token.Pos to the
   corresponding *ast.Field.

2. To display an object's documentation we must refer to the AST. We
   need the object's declaring node and any containing ast.Decl. We
   now maintain a special PosToDecl cache so we can avoid the PEI call
   in this case as well.

We can't use a single cache for both because the *ast.Field's position
is present in both caches (but points to different nodes). The caches
are memoized to defer generation until they are needed and to save
work creating them if the *ast.Files haven't changed.

These changes speed up completing the fields of
github.com/aws/aws-sdk-go/service/ec2 from 18.5s to 45ms on my laptop.

Fixes golang/go#37450.

Change-Id: I25cc5ea39551db728a2348f346342ebebeddd049
Reviewed-on: https://go-review.googlesource.com/c/tools/+/221021
Run-TryBot: Muir Manders <muir@mnd.rs>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index a71106c..2c2e708 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -92,8 +92,10 @@
 				dep.check(ctx)
 			}(dep)
 		}
+
 		data := &packageData{}
 		data.pkg, data.err = typeCheck(ctx, fset, m, mode, goFiles, compiledGoFiles, deps)
+
 		return data
 	})
 	ph.handle = h
@@ -413,6 +415,7 @@
 			}
 		}
 	}
+
 	return pkg, nil
 }
 
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index a3b1427..75d8888 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -28,10 +28,15 @@
 	mode source.ParseMode
 }
 
+// astCacheKey is similar to parseKey, but is a distinct type because
+// it is used to key a different value within the same map.
+type astCacheKey parseKey
+
 type parseGoHandle struct {
-	handle *memoize.Handle
-	file   source.FileHandle
-	mode   source.ParseMode
+	handle         *memoize.Handle
+	file           source.FileHandle
+	mode           source.ParseMode
+	astCacheHandle *memoize.Handle
 }
 
 type parseGoData struct {
@@ -63,10 +68,14 @@
 	h := c.store.Bind(key, func(ctx context.Context) interface{} {
 		return parseGo(ctx, fset, fh, mode)
 	})
+
 	return &parseGoHandle{
 		handle: h,
 		file:   fh,
 		mode:   mode,
+		astCacheHandle: c.store.Bind(astCacheKey(key), func(ctx context.Context) interface{} {
+			return buildASTCache(ctx, h)
+		}),
 	}
 }
 
@@ -111,6 +120,133 @@
 	return data.ast, data.src, data.mapper, data.parseError, data.err
 }
 
+func (pgh *parseGoHandle) PosToDecl(ctx context.Context) (map[token.Pos]ast.Decl, error) {
+	v, err := pgh.astCacheHandle.Get(ctx)
+	if err != nil || v == nil {
+		return nil, err
+	}
+
+	data := v.(*astCacheData)
+	if data.err != nil {
+		return nil, data.err
+	}
+
+	return data.posToDecl, nil
+}
+
+func (pgh *parseGoHandle) PosToField(ctx context.Context) (map[token.Pos]*ast.Field, error) {
+	v, err := pgh.astCacheHandle.Get(ctx)
+	if err != nil || v == nil {
+		return nil, err
+	}
+
+	data := v.(*astCacheData)
+	if data.err != nil {
+		return nil, data.err
+	}
+
+	return data.posToField, nil
+}
+
+type astCacheData struct {
+	memoize.NoCopy
+
+	err error
+
+	posToDecl  map[token.Pos]ast.Decl
+	posToField map[token.Pos]*ast.Field
+}
+
+// buildASTCache builds caches to aid in quickly going from the typed
+// world to the syntactic world.
+func buildASTCache(ctx context.Context, parseHandle *memoize.Handle) *astCacheData {
+	var (
+		// path contains all ancestors, including n.
+		path []ast.Node
+		// decls contains all ancestors that are decls.
+		decls []ast.Decl
+	)
+
+	v, err := parseHandle.Get(ctx)
+	if err != nil || v == nil || v.(*parseGoData).ast == nil {
+		return &astCacheData{err: err}
+	}
+
+	data := &astCacheData{
+		posToDecl:  make(map[token.Pos]ast.Decl),
+		posToField: make(map[token.Pos]*ast.Field),
+	}
+
+	ast.Inspect(v.(*parseGoData).ast, func(n ast.Node) bool {
+		if n == nil {
+			lastP := path[len(path)-1]
+			path = path[:len(path)-1]
+			if len(decls) > 0 && decls[len(decls)-1] == lastP {
+				decls = decls[:len(decls)-1]
+			}
+			return false
+		}
+
+		path = append(path, n)
+
+		switch n := n.(type) {
+		case *ast.Field:
+			addField := func(f ast.Node) {
+				if f.Pos().IsValid() {
+					data.posToField[f.Pos()] = n
+					if len(decls) > 0 {
+						data.posToDecl[f.Pos()] = decls[len(decls)-1]
+					}
+				}
+			}
+
+			// Add mapping for *ast.Field itself. This handles embedded
+			// fields which have no associated *ast.Ident name.
+			addField(n)
+
+			// Add mapping for each field name since you can have
+			// multiple names for the same type expression.
+			for _, name := range n.Names {
+				addField(name)
+			}
+
+			// Also map "X" in "...X" to the containing *ast.Field. This
+			// makes it easy to format variadic signature params
+			// properly.
+			if elips, ok := n.Type.(*ast.Ellipsis); ok && elips.Elt != nil {
+				addField(elips.Elt)
+			}
+		case *ast.FuncDecl:
+			decls = append(decls, n)
+
+			if n.Name != nil && n.Name.Pos().IsValid() {
+				data.posToDecl[n.Name.Pos()] = n
+			}
+		case *ast.GenDecl:
+			decls = append(decls, n)
+
+			for _, spec := range n.Specs {
+				switch spec := spec.(type) {
+				case *ast.TypeSpec:
+					if spec.Name != nil && spec.Name.Pos().IsValid() {
+						data.posToDecl[spec.Name.Pos()] = n
+					}
+				case *ast.ValueSpec:
+					for _, id := range spec.Names {
+						if id != nil && id.Pos().IsValid() {
+							data.posToDecl[id.Pos()] = n
+						}
+					}
+				}
+			}
+		}
+
+		return true
+	})
+
+	return data
+}
+
 func hashParseKeys(pghs []*parseGoHandle) string {
 	b := bytes.NewBuffer(nil)
 	for _, pgh := range pghs {
diff --git a/internal/lsp/source/completion_format.go b/internal/lsp/source/completion_format.go
index ffa85be..664f8ff 100644
--- a/internal/lsp/source/completion_format.go
+++ b/internal/lsp/source/completion_format.go
@@ -187,15 +187,22 @@
 	if cand.imp != nil && cand.imp.pkg != nil {
 		searchPkg = cand.imp.pkg
 	}
-	file, pkg, err := findPosInPackage(c.snapshot.View(), searchPkg, obj.Pos())
+
+	ph, pkg, err := findPosInPackage(c.snapshot.View(), searchPkg, obj.Pos())
 	if err != nil {
 		return item, nil
 	}
-	ident, err := findIdentifier(ctx, c.snapshot, pkg, file, obj.Pos())
+
+	posToDecl, err := ph.PosToDecl(ctx)
 	if err != nil {
+		return CompletionItem{}, err
+	}
+	decl := posToDecl[obj.Pos()]
+	if decl == nil {
 		return item, nil
 	}
-	hover, err := HoverIdentifier(ctx, ident)
+
+	hover, err := hoverInfo(pkg, obj, decl)
 	if err != nil {
 		event.Error(ctx, "failed to find Hover", err, tag.URI.Of(uri))
 		return item, nil
diff --git a/internal/lsp/source/hover.go b/internal/lsp/source/hover.go
index 8fc6974..66f1e20 100644
--- a/internal/lsp/source/hover.go
+++ b/internal/lsp/source/hover.go
@@ -102,10 +102,7 @@
 		h.SingleLine = objectString(obj, i.qf)
 	}
 	h.ImportPath, h.Link, h.SymbolName = pathLinkAndSymbolName(i)
-	if h.comment != nil {
-		h.FullDocumentation = h.comment.Text()
-		h.Synopsis = doc.Synopsis(h.FullDocumentation)
-	}
+
 	return h, nil
 }
 
@@ -217,13 +214,18 @@
 	_, done := event.Start(ctx, "source.hover")
 	defer done()
 
-	obj := d.obj
-	switch node := d.node.(type) {
+	return hoverInfo(pkg, d.obj, d.node)
+}
+
+func hoverInfo(pkg Package, obj types.Object, node ast.Node) (*HoverInformation, error) {
+	var info *HoverInformation
+
+	switch node := node.(type) {
 	case *ast.Ident:
 		// The package declaration.
 		for _, f := range pkg.GetSyntax() {
 			if f.Name == node {
-				return &HoverInformation{comment: f.Doc}, nil
+				info = &HoverInformation{comment: f.Doc}
 			}
 		}
 	case *ast.ImportSpec:
@@ -238,32 +240,47 @@
 			var doc *ast.CommentGroup
 			for _, file := range imp.GetSyntax() {
 				if file.Doc != nil {
-					return &HoverInformation{source: obj, comment: doc}, nil
+					info = &HoverInformation{source: obj, comment: doc}
 				}
 			}
 		}
-		return &HoverInformation{source: node}, nil
+		info = &HoverInformation{source: node}
 	case *ast.GenDecl:
 		switch obj := obj.(type) {
 		case *types.TypeName, *types.Var, *types.Const, *types.Func:
-			return formatGenDecl(node, obj, obj.Type())
+			var err error
+			info, err = formatGenDecl(node, obj, obj.Type())
+			if err != nil {
+				return nil, err
+			}
 		}
 	case *ast.TypeSpec:
 		if obj.Parent() == types.Universe {
 			if obj.Name() == "error" {
-				return &HoverInformation{source: node}, nil
+				info = &HoverInformation{source: node}
+			} else {
+				info = &HoverInformation{source: node.Name} // comments not needed for builtins
 			}
-			return &HoverInformation{source: node.Name}, nil // comments not needed for builtins
 		}
 	case *ast.FuncDecl:
 		switch obj.(type) {
 		case *types.Func:
-			return &HoverInformation{source: obj, comment: node.Doc}, nil
+			info = &HoverInformation{source: obj, comment: node.Doc}
 		case *types.Builtin:
-			return &HoverInformation{source: node.Type, comment: node.Doc}, nil
+			info = &HoverInformation{source: node.Type, comment: node.Doc}
 		}
 	}
-	return &HoverInformation{source: obj}, nil
+
+	if info == nil {
+		info = &HoverInformation{source: obj}
+	}
+
+	if info.comment != nil {
+		info.FullDocumentation = info.comment.Text()
+		info.Synopsis = doc.Synopsis(info.FullDocumentation)
+	}
+
+	return info, nil
 }
 
 func formatGenDecl(node *ast.GenDecl, obj types.Object, typ types.Type) (*HoverInformation, error) {
@@ -283,6 +300,7 @@
 	if spec == nil {
 		return nil, errors.Errorf("no spec for node %v at position %v", node, obj.Pos())
 	}
+
 	// If we have a field or method.
 	switch obj.(type) {
 	case *types.Var, *types.Const, *types.Func:
diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go
index 0c6d288..08e1050 100644
--- a/internal/lsp/source/identifier.go
+++ b/internal/lsp/source/identifier.go
@@ -12,7 +12,6 @@
 	"go/types"
 	"strconv"
 
-	"golang.org/x/tools/go/ast/astutil"
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/lsp/protocol"
 	errors "golang.org/x/xerrors"
@@ -203,7 +202,7 @@
 	}
 	result.Declaration.MappedRange = append(result.Declaration.MappedRange, rng)
 
-	if result.Declaration.node, err = objToNode(s.View(), pkg, result.Declaration.obj); err != nil {
+	if result.Declaration.node, err = objToDecl(ctx, view, pkg, result.Declaration.obj); err != nil {
 		return nil, err
 	}
 	typ := pkg.GetTypesInfo().TypeOf(result.ident)
@@ -261,31 +260,18 @@
 	return types.IsInterface(obj.Type()) && obj.Pkg() == nil && obj.Name() == "error"
 }
 
-func objToNode(v View, pkg Package, obj types.Object) (ast.Decl, error) {
-	declAST, _, err := findPosInPackage(v, pkg, obj.Pos())
+func objToDecl(ctx context.Context, v View, srcPkg Package, obj types.Object) (ast.Decl, error) {
+	ph, _, err := findPosInPackage(v, srcPkg, obj.Pos())
 	if err != nil {
 		return nil, err
 	}
-	path, _ := astutil.PathEnclosingInterval(declAST, obj.Pos(), obj.Pos())
-	if path == nil {
-		return nil, errors.Errorf("no path for object %v", obj.Name())
+
+	posToDecl, err := ph.PosToDecl(ctx)
+	if err != nil {
+		return nil, err
 	}
-	for _, node := range path {
-		switch node := node.(type) {
-		case *ast.GenDecl:
-			// Type names, fields, and methods.
-			switch obj.(type) {
-			case *types.TypeName, *types.Var, *types.Const, *types.Func:
-				return node, nil
-			}
-		case *ast.FuncDecl:
-			// Function signatures.
-			if _, ok := obj.(*types.Func); ok {
-				return node, nil
-			}
-		}
-	}
-	return nil, nil // didn't find a node, but don't fail
+
+	return posToDecl[obj.Pos()], nil
 }
 
 // importSpec handles positions inside of an *ast.ImportSpec.
diff --git a/internal/lsp/source/signature_help.go b/internal/lsp/source/signature_help.go
index 69e5fc5..aef0a28 100644
--- a/internal/lsp/source/signature_help.go
+++ b/internal/lsp/source/signature_help.go
@@ -99,7 +99,7 @@
 		comment *ast.CommentGroup
 	)
 	if obj != nil {
-		node, err := objToNode(snapshot.View(), pkg, obj)
+		node, err := objToDecl(ctx, snapshot.View(), pkg, obj)
 		if err != nil {
 			return nil, 0, err
 		}
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index 1ec2430..d6af8a1 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -858,7 +858,7 @@
 		Signatures:      []protocol.SignatureInformation{*gotSignature},
 		ActiveParameter: float64(gotActiveParameter),
 	}
-	if diff := tests.DiffSignatures(spn, got, want); diff != "" {
+	if diff := tests.DiffSignatures(spn, want, got); diff != "" {
 		t.Error(diff)
 	}
 }
diff --git a/internal/lsp/source/types_format.go b/internal/lsp/source/types_format.go
index 53323f9..d07ee73 100644
--- a/internal/lsp/source/types_format.go
+++ b/internal/lsp/source/types_format.go
@@ -195,21 +195,16 @@
 // To do this, it looks in the AST of the file in which the object is declared.
 // On any errors, it always fallbacks back to types.TypeString.
 func formatVarType(ctx context.Context, s Snapshot, srcpkg Package, srcfile *ast.File, obj *types.Var, qf types.Qualifier) string {
-	file, pkg, err := findPosInPackage(s.View(), srcpkg, obj.Pos())
+	ph, pkg, err := findPosInPackage(s.View(), srcpkg, obj.Pos())
 	if err != nil {
 		return types.TypeString(obj.Type(), qf)
 	}
-	// Named and unnamed variables must be handled differently.
-	// Unnamed variables appear in the result values of a function signature.
-	var expr ast.Expr
-	if obj.Name() != "" {
-		expr, err = namedVarType(ctx, s, pkg, file, obj)
-	} else {
-		expr, err = unnamedVarType(file, obj)
-	}
+
+	expr, err := varType(ctx, ph, obj)
 	if err != nil {
 		return types.TypeString(obj.Type(), qf)
 	}
+
 	// The type names in the AST may not be correctly qualified.
 	// Determine the package name to use based on the package that originated
 	// the query and the package in which the type is declared.
@@ -224,43 +219,19 @@
 	return fmted
 }
 
-// unnamedVarType finds the type for an unnamed variable.
-func unnamedVarType(file *ast.File, obj *types.Var) (ast.Expr, error) {
-	path, _ := astutil.PathEnclosingInterval(file, obj.Pos(), obj.Pos())
-	var expr ast.Expr
-	for _, p := range path {
-		e, ok := p.(ast.Expr)
-		if !ok {
-			break
-		}
-		expr = e
-	}
-	typ, ok := expr.(ast.Expr)
-	if !ok {
-		return nil, fmt.Errorf("unexpected type for node (%T)", path[0])
-	}
-	return typ, nil
-}
-
-// namedVarType returns the type for a named variable.
-func namedVarType(ctx context.Context, s Snapshot, pkg Package, file *ast.File, obj *types.Var) (ast.Expr, error) {
-	ident, err := findIdentifier(ctx, s, pkg, file, obj.Pos())
+// varType returns the type expression for a *types.Var.
+func varType(ctx context.Context, ph ParseGoHandle, obj *types.Var) (ast.Expr, error) {
+	posToField, err := ph.PosToField(ctx)
 	if err != nil {
 		return nil, err
 	}
-	if ident.Declaration.obj != obj {
-		return nil, fmt.Errorf("expected the ident's declaration %v to be equal to obj %v", ident.Declaration.obj, obj)
-	}
-	if i := ident.ident; i == nil || i.Obj == nil || i.Obj.Decl == nil {
+	field := posToField[obj.Pos()]
+	if field == nil {
 		return nil, fmt.Errorf("no declaration for object %s", obj.Name())
 	}
-	f, ok := ident.ident.Obj.Decl.(*ast.Field)
+	typ, ok := field.Type.(ast.Expr)
 	if !ok {
-		return nil, fmt.Errorf("declaration of object %v is %T, not *ast.Field", obj.Name(), ident.ident.Obj.Decl)
-	}
-	typ, ok := f.Type.(ast.Expr)
-	if !ok {
-		return nil, fmt.Errorf("unexpected type for node (%T)", f.Type)
+		return nil, fmt.Errorf("unexpected type for node (%T)", field.Type)
 	}
 	return typ, nil
 }
diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go
index ca0df19..6afe7d7 100644
--- a/internal/lsp/source/util.go
+++ b/internal/lsp/source/util.go
@@ -529,7 +529,7 @@
 	return 1
 }
 
-func findPosInPackage(v View, searchpkg Package, pos token.Pos) (*ast.File, Package, error) {
+func findPosInPackage(v View, searchpkg Package, pos token.Pos) (ParseGoHandle, Package, error) {
 	tok := v.Session().Cache().FileSet().File(pos)
 	if tok == nil {
 		return nil, nil, errors.Errorf("no file for pos in package %s", searchpkg.ID())
@@ -540,14 +540,7 @@
 	if err != nil {
 		return nil, nil, err
 	}
-	file, _, _, _, err := ph.Cached()
-	if err != nil {
-		return nil, nil, err
-	}
-	if !(file.Pos() <= pos && pos <= file.End()) {
-		return nil, nil, fmt.Errorf("pos %v, apparently in file %q, is not between %v and %v", pos, ph.File().URI(), file.Pos(), file.End())
-	}
-	return file, pkg, nil
+	return ph, pkg, nil
 }
 
 func findMapperInPackage(v View, searchpkg Package, uri span.URI) (*protocol.ColumnMapper, error) {
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index 92230e4..fb6e3f5 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -303,6 +303,17 @@
 
 	// Cached returns the AST for this handle, if it has already been stored.
 	Cached() (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error)
+
+	// PosToField is a cache of *ast.Fields by token.Pos. This allows us
+	// to quickly find corresponding *ast.Field node given a *types.Var.
+	// We must refer to the AST to render type aliases properly when
+	// formatting signatures and other types.
+	PosToField(context.Context) (map[token.Pos]*ast.Field, error)
+
+	// PosToDecl maps certain objects' positions to their surrounding
+	// ast.Decl. This mapping is used when building the documentation
+	// string for the objects.
+	PosToDecl(context.Context) (map[token.Pos]ast.Decl, error)
 }
 
 type ParseModHandle interface {
diff --git a/internal/lsp/testdata/lsp/primarymod/deep/deep.go b/internal/lsp/testdata/lsp/primarymod/deep/deep.go
index 09dd1b8..08f18b3 100644
--- a/internal/lsp/testdata/lsp/primarymod/deep/deep.go
+++ b/internal/lsp/testdata/lsp/primarymod/deep/deep.go
@@ -34,7 +34,7 @@
 		*deepCircle
 	}
 	var circle deepCircle   //@item(deepCircle, "circle", "deepCircle", "var")
-	*circle.deepCircle      //@item(deepCircleField, "*circle.deepCircle", "*deepCircle", "field", "deepCircle is circular.")
+	*circle.deepCircle      //@item(deepCircleField, "*circle.deepCircle", "*deepCircle", "field")
 	var _ deepCircle = circ //@deep(" //", deepCircle, deepCircleField)
 }