internal/lsp: trim ASTs for which we do not require function bodies

This change trims the function bodies from the ASTs of files belonging to
dependency packages. In these cases, we do not necessarily need full
file ASTs, so it's not necessary to store the function bodies in memory.

This change will reduce memory usage. However, it will also slow down
the case of a user opening a file in a dependency package, as we will
have to re-typecheck the file to get the full AST. Hopefully, this
increase in latency will not be significant, as we will only need to
re-typecheck a single package (all the dependencies should be cached).

Updates golang/go#30309

Change-Id: I7871ae44499c851d1097087bd9d3567bb27db691
Reviewed-on: https://go-review.googlesource.com/c/tools/+/178719
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index e6feadd..b31178a 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -24,6 +24,9 @@
 	// If we have seen a package that is already in this map, we have a circular import.
 	seen map[string]struct{}
 
+	// topLevelPkgPath is the path of the package from which type-checking began.
+	topLevelPkgPath string
+
 	ctx  context.Context
 	fset *token.FileSet
 }
@@ -96,7 +99,9 @@
 	appendError := func(err error) {
 		imp.view.appendPkgError(pkg, err)
 	}
-	files, errs := imp.parseFiles(meta.files)
+
+	// Don't type-check function bodies if we are not in the top-level package.
+	files, errs := imp.parseFiles(meta.files, imp.ignoreFuncBodies(pkg.pkgPath))
 	for _, err := range errs {
 		appendError(err)
 	}
@@ -112,10 +117,11 @@
 	cfg := &types.Config{
 		Error: appendError,
 		Importer: &importer{
-			view: imp.view,
-			seen: seen,
-			ctx:  imp.ctx,
-			fset: imp.fset,
+			view:            imp.view,
+			ctx:             imp.ctx,
+			fset:            imp.fset,
+			topLevelPkgPath: imp.topLevelPkgPath,
+			seen:            seen,
 		},
 	}
 	check := types.NewChecker(cfg, imp.fset, pkg.types, pkg.typesInfo)
@@ -151,8 +157,11 @@
 			continue
 		}
 		gof.token = tok
-		gof.ast = file
-		gof.imports = gof.ast.Imports
+		gof.ast = &astFile{
+			file:      file,
+			isTrimmed: imp.ignoreFuncBodies(pkg.pkgPath),
+		}
+		gof.imports = file.Imports
 		gof.pkg = pkg
 	}
 
@@ -198,3 +207,7 @@
 	}
 	pkg.errors = append(pkg.errors, errs...)
 }
+
+func (imp *importer) ignoreFuncBodies(pkgPath string) bool {
+	return imp.topLevelPkgPath != pkgPath
+}
diff --git a/internal/lsp/cache/gofile.go b/internal/lsp/cache/gofile.go
index 4a8c0d1..ac6d5c4 100644
--- a/internal/lsp/cache/gofile.go
+++ b/internal/lsp/cache/gofile.go
@@ -13,15 +13,22 @@
 type goFile struct {
 	fileBase
 
-	ast     *ast.File
+	ast *astFile
+
 	pkg     *pkg
 	meta    *metadata
 	imports []*ast.ImportSpec
 }
 
+type astFile struct {
+	file      *ast.File
+	isTrimmed bool
+}
+
 func (f *goFile) GetToken(ctx context.Context) *token.File {
 	f.view.mu.Lock()
 	defer f.view.mu.Unlock()
+
 	if f.isDirty() {
 		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
 			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
@@ -31,7 +38,7 @@
 	return f.token
 }
 
-func (f *goFile) GetAST(ctx context.Context) *ast.File {
+func (f *goFile) GetTrimmedAST(ctx context.Context) *ast.File {
 	f.view.mu.Lock()
 	defer f.view.mu.Unlock()
 
@@ -41,14 +48,27 @@
 			return nil
 		}
 	}
-	return f.ast
+	return f.ast.file
+}
+
+func (f *goFile) GetAST(ctx context.Context) *ast.File {
+	f.view.mu.Lock()
+	defer f.view.mu.Unlock()
+
+	if f.isDirty() || f.astIsTrimmed() {
+		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
+			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
+			return nil
+		}
+	}
+	return f.ast.file
 }
 
 func (f *goFile) GetPackage(ctx context.Context) source.Package {
 	f.view.mu.Lock()
 	defer f.view.mu.Unlock()
 
-	if f.isDirty() {
+	if f.isDirty() || f.astIsTrimmed() {
 		if errs, err := f.view.loadParseTypecheck(ctx, f); err != nil {
 			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
 
@@ -68,6 +88,10 @@
 	return f.meta == nil || f.imports == nil || f.token == nil || f.ast == nil || f.pkg == nil || len(f.view.contentChanges) > 0
 }
 
+func (f *goFile) astIsTrimmed() bool {
+	return f.ast != nil && f.ast.isTrimmed
+}
+
 func (f *goFile) GetActiveReverseDeps(ctx context.Context) []source.GoFile {
 	pkg := f.GetPackage(ctx)
 	if pkg == nil {
diff --git a/internal/lsp/cache/load.go b/internal/lsp/cache/load.go
index aebe5f2..772c350 100644
--- a/internal/lsp/cache/load.go
+++ b/internal/lsp/cache/load.go
@@ -23,20 +23,24 @@
 	if !f.isDirty() {
 		return nil, nil
 	}
-	// Check if the file's imports have changed. If they have, update the
-	// metadata by calling packages.Load.
+
+	// Check if we need to run go/packages.Load for this file's package.
 	if errs, err := v.checkMetadata(ctx, f); err != nil {
 		return errs, err
 	}
+
 	if f.meta == nil {
-		return nil, fmt.Errorf("no metadata found for %v", f.filename())
+		return nil, fmt.Errorf("loadParseTypecheck: no metadata found for %v", f.filename())
 	}
+
 	imp := &importer{
-		view: v,
-		seen: make(map[string]struct{}),
-		ctx:  ctx,
-		fset: f.FileSet(),
+		view:            v,
+		seen:            make(map[string]struct{}),
+		ctx:             ctx,
+		fset:            f.FileSet(),
+		topLevelPkgPath: f.meta.pkgPath,
 	}
+
 	// Start prefetching direct imports.
 	for importPath := range f.meta.children {
 		go imp.Import(importPath)
@@ -53,37 +57,40 @@
 	return nil, nil
 }
 
+// checkMetadata determines whether go/packages.Load must be run for this file.
+// If so, it updates the metadata for the file and its package.
 func (v *view) checkMetadata(ctx context.Context, f *goFile) ([]packages.Error, error) {
-	if v.reparseImports(ctx, f, f.filename()) {
-		cfg := v.buildConfig()
-		pkgs, err := packages.Load(cfg, fmt.Sprintf("file=%s", f.filename()))
-		if len(pkgs) == 0 {
-			if err == nil {
-				err = fmt.Errorf("%s: no packages found", f.filename())
-			}
-			// Return this error as a diagnostic to the user.
-			return []packages.Error{
-				{
-					Msg:  err.Error(),
-					Kind: packages.ListError,
-				},
-			}, err
+	if !v.reparseImports(ctx, f) {
+		return nil, nil
+	}
+	pkgs, err := packages.Load(v.buildConfig(), fmt.Sprintf("file=%s", f.filename()))
+	if len(pkgs) == 0 {
+		if err == nil {
+			err = fmt.Errorf("%s: no packages found", f.filename())
 		}
-		for _, pkg := range pkgs {
-			// If the package comes back with errors from `go list`, don't bother
-			// type-checking it.
-			if len(pkg.Errors) > 0 {
-				return pkg.Errors, fmt.Errorf("package %s has errors, skipping type-checking", pkg.PkgPath)
-			}
-			v.link(ctx, pkg.PkgPath, pkg, nil)
+		// Return this error as a diagnostic to the user.
+		return []packages.Error{
+			{
+				Msg:  err.Error(),
+				Kind: packages.ListError,
+			},
+		}, err
+	}
+	for _, pkg := range pkgs {
+		// If the package comes back with errors from `go list`,
+		// don't bother type-checking it.
+		if len(pkg.Errors) > 0 {
+			return pkg.Errors, fmt.Errorf("package %s has errors, skipping type-checking", pkg.PkgPath)
 		}
+		// Build the import graph for this package.
+		v.link(ctx, pkg.PkgPath, pkg, nil)
 	}
 	return nil, nil
 }
 
-// reparseImports reparses a file's import declarations to determine if they
-// have changed.
-func (v *view) reparseImports(ctx context.Context, f *goFile, filename string) bool {
+// reparseImports reparses a file's package and import declarations to
+// determine if they have changed.
+func (v *view) reparseImports(ctx context.Context, f *goFile) bool {
 	if f.meta == nil {
 		return true
 	}
@@ -92,7 +99,7 @@
 	if f.fc.Error != nil {
 		return true
 	}
-	parsed, _ := parser.ParseFile(f.FileSet(), filename, f.fc.Data, parser.ImportsOnly)
+	parsed, _ := parser.ParseFile(f.FileSet(), f.filename(), f.fc.Data, parser.ImportsOnly)
 	if parsed == nil {
 		return true
 	}
@@ -129,12 +136,11 @@
 	m.files = pkg.CompiledGoFiles
 	for _, filename := range m.files {
 		if f, _ := v.getFile(span.FileURI(filename)); f != nil {
-			gof, ok := f.(*goFile)
-			if !ok {
-				v.Session().Logger().Errorf(ctx, "not a go file: %v", f.URI())
-				continue
+			if gof, ok := f.(*goFile); ok {
+				gof.meta = m
+			} else {
+				v.Session().Logger().Errorf(ctx, "not a Go file: %s", f.URI())
 			}
-			gof.meta = m
 		}
 	}
 	// Connect the import graph.
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index 4f8cd35..84ce104 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -34,7 +34,7 @@
 // Because files are scanned in parallel, the token.Pos
 // positions of the resulting ast.Files are not ordered.
 //
-func (imp *importer) parseFiles(filenames []string) ([]*ast.File, []error) {
+func (imp *importer) parseFiles(filenames []string, ignoreFuncBodies bool) ([]*ast.File, []error) {
 	var wg sync.WaitGroup
 	n := len(filenames)
 	parsed := make([]*ast.File, n)
@@ -54,7 +54,7 @@
 		}
 		gof, ok := f.(*goFile)
 		if !ok {
-			parsed[i], errors[i] = nil, fmt.Errorf("Non go file in parse call: %v", filename)
+			parsed[i], errors[i] = nil, fmt.Errorf("non-Go file in parse call: %v", filename)
 			continue
 		}
 
@@ -66,24 +66,25 @@
 				wg.Done()
 			}()
 
-			if gof.ast != nil { // already have an ast
-				parsed[i], errors[i] = gof.ast, nil
+			// If we already have a cached AST, reuse it.
+			// If the AST is trimmed, only use it if we are ignoring function bodies.
+			if gof.ast != nil && (!gof.ast.isTrimmed || ignoreFuncBodies) {
+				parsed[i], errors[i] = gof.ast.file, nil
 				return
 			}
 
-			// No cached AST for this file, so try parsing it.
+			// We don't have a usable cached AST for this file (none, or only a trimmed one), so we read its content and parse it.
 			gof.read(imp.ctx)
-			if gof.fc.Error != nil { // file content error, so abort
+			if gof.fc.Error != nil {
 				return
 			}
-
 			src := gof.fc.Data
-			if src == nil { // no source
-				parsed[i], errors[i] = nil, fmt.Errorf("No source for %v", filename)
+			if src == nil {
+				parsed[i], errors[i] = nil, fmt.Errorf("no source for %v", filename)
 				return
 			}
 
-			// ParseFile may return a partial AST AND an error.
+			// ParseFile may return a partial AST and an error.
 			parsed[i], errors[i] = parseFile(imp.fset, filename, src)
 
 			// Fix any badly parsed parts of the AST.
@@ -140,8 +141,44 @@
 	return false
 }
 
-// fix inspects and potentially modifies any *ast.BadStmts or *ast.BadExprs in the AST.
+// trimAST clears function bodies, statement lists, and most composite
+// literal elements from the AST of file, modifying it in place.
+func trimAST(file *ast.File) {
+	ast.Inspect(file, func(n ast.Node) bool {
+		if n == nil {
+			return false
+		}
+		switch n := n.(type) {
+		case *ast.FuncDecl:
+			n.Body = nil
+		case *ast.BlockStmt:
+			n.List = nil
+		case *ast.CaseClause:
+			n.Body = nil
+		case *ast.CommClause:
+			n.Body = nil
+		case *ast.CompositeLit:
+			// Leave elts in place for [...]T
+			// array literals, because they can
+			// affect the expression's type.
+			if !isEllipsisArray(n.Type) {
+				n.Elts = nil
+			}
+		}
+		return true
+	})
+}
 
+func isEllipsisArray(n ast.Expr) bool {
+	at, ok := n.(*ast.ArrayType)
+	if !ok {
+		return false
+	}
+	_, ok = at.Len.(*ast.Ellipsis)
+	return ok
+}
+
+// fix inspects and potentially modifies any *ast.BadStmts or *ast.BadExprs in the AST.
 // We attempt to modify the AST such that we can type-check it more effectively.
 func (v *view) fix(ctx context.Context, file *ast.File, tok *token.File, src []byte) {
 	var parent ast.Node
diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go
index affa130..02df6b0 100644
--- a/internal/lsp/source/identifier.go
+++ b/internal/lsp/source/identifier.go
@@ -181,9 +181,15 @@
 	}
 	declFile, ok := f.(GoFile)
 	if !ok {
-		return nil, fmt.Errorf("not a go file %v", s.URI())
+		return nil, fmt.Errorf("not a Go file %v", s.URI())
 	}
-	declAST := declFile.GetAST(ctx)
+	// If the object is exported, we don't need the full AST to find its definition.
+	var declAST *ast.File
+	if obj.Exported() {
+		declAST = declFile.GetTrimmedAST(ctx)
+	} else {
+		declAST = declFile.GetAST(ctx)
+	}
 	path, _ := astutil.PathEnclosingInterval(declAST, rng.Start, rng.End)
 	if path == nil {
 		return nil, fmt.Errorf("no path for range %v", rng)
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index fa22091..67ed7e3 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -32,7 +32,6 @@
 }
 
 func testSource(t *testing.T, exporter packagestest.Exporter) {
-	ctx := context.Background()
 	data := tests.Load(t, exporter, "../testdata")
 	defer data.Exported.Cleanup()
 
@@ -45,7 +44,7 @@
 	}
 	r.view.SetEnv(data.Config.Env)
 	for filename, content := range data.Config.Overlay {
-		r.view.SetContent(ctx, span.FileURI(filename), content)
+		session.SetOverlay(span.FileURI(filename), content)
 	}
 	tests.Run(t, r, data)
 }
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index e00cff8..9091646 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -148,7 +148,15 @@
 // GoFile represents a Go source file that has been type-checked.
 type GoFile interface {
 	File
+
+	// GetTrimmedAST returns an AST that may or may not contain function bodies.
+	// It should be used in scenarios where function bodies are not necessary.
+	GetTrimmedAST(ctx context.Context) *ast.File
+
+	// GetAST returns the full AST for the file.
 	GetAST(ctx context.Context) *ast.File
+
+	// GetPackage returns the package that this file belongs to.
 	GetPackage(ctx context.Context) Package
 
 	// GetActiveReverseDeps returns the active files belonging to the reverse