internal/lsp: fix some issues with trimming ASTs

This change correctly invalidates the cache when we
have to go from a trimmed to untrimmed AST.

The "ignoreFuncBodies" behavior is still disabled due to a racy test.

Updates golang/go#30309

Change-Id: I6b89d1d2140d77517616cb3956721a157c25ab71
Reviewed-on: https://go-review.googlesource.com/c/tools/+/180857
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index b31178a..45b7bc8 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -24,8 +24,8 @@
 	// If we have seen a package that is already in this map, we have a circular import.
 	seen map[string]struct{}
 
-	// topLevelPkgPath is the path of the package from which type-checking began.
-	topLevelPkgPath string
+	// topLevelPkgID is the ID of the package from which type-checking began.
+	topLevelPkgID string
 
 	ctx  context.Context
 	fset *token.FileSet
@@ -100,8 +100,12 @@
 		imp.view.appendPkgError(pkg, err)
 	}
 
+	// Ignore function bodies for any dependency packages.
+	// TODO: Enable this.
+	ignoreFuncBodies := false
+
 	// Don't type-check function bodies if we are not in the top-level package.
-	files, errs := imp.parseFiles(meta.files, imp.ignoreFuncBodies(pkg.pkgPath))
+	files, errs := imp.parseFiles(meta.files, ignoreFuncBodies)
 	for _, err := range errs {
 		appendError(err)
 	}
@@ -115,17 +119,18 @@
 	seen[pkgPath] = struct{}{}
 
 	cfg := &types.Config{
-		Error: appendError,
+		Error:            appendError,
+		IgnoreFuncBodies: ignoreFuncBodies,
 		Importer: &importer{
-			view:            imp.view,
-			ctx:             imp.ctx,
-			fset:            imp.fset,
-			topLevelPkgPath: imp.topLevelPkgPath,
-			seen:            seen,
+			view:          imp.view,
+			ctx:           imp.ctx,
+			fset:          imp.fset,
+			topLevelPkgID: imp.topLevelPkgID,
+			seen:          seen,
 		},
 	}
 	check := types.NewChecker(cfg, imp.fset, pkg.types, pkg.typesInfo)
-	check.Files(pkg.syntax)
+	check.Files(pkg.GetSyntax())
 
 	// Add every file in this package to our cache.
 	imp.cachePackage(imp.ctx, pkg, meta)
@@ -134,15 +139,15 @@
 }
 
 func (imp *importer) cachePackage(ctx context.Context, pkg *pkg, meta *metadata) {
-	for _, file := range pkg.GetSyntax() {
+	for _, fAST := range pkg.syntax {
 		// TODO: If a file is in multiple packages, which package do we store?
-		if !file.Pos().IsValid() {
-			imp.view.Session().Logger().Errorf(ctx, "invalid position for file %v", file.Name)
+		if !fAST.file.Pos().IsValid() {
+			imp.view.Session().Logger().Errorf(ctx, "invalid position for file %v", fAST.file.Name)
 			continue
 		}
-		tok := imp.view.Session().Cache().FileSet().File(file.Pos())
+		tok := imp.view.Session().Cache().FileSet().File(fAST.file.Pos())
 		if tok == nil {
-			imp.view.Session().Logger().Errorf(ctx, "no token.File for %v", file.Name)
+			imp.view.Session().Logger().Errorf(ctx, "no token.File for %v", fAST.file.Name)
 			continue
 		}
 		fURI := span.FileURI(tok.Name())
@@ -153,15 +158,12 @@
 		}
 		gof, ok := f.(*goFile)
 		if !ok {
-			imp.view.Session().Logger().Errorf(ctx, "not a go file: %v", f.URI())
+			imp.view.Session().Logger().Errorf(ctx, "%v is not a Go file", f.URI())
 			continue
 		}
 		gof.token = tok
-		gof.ast = &astFile{
-			file:      file,
-			isTrimmed: imp.ignoreFuncBodies(pkg.pkgPath),
-		}
-		gof.imports = file.Imports
+		gof.ast = fAST
+		gof.imports = fAST.file.Imports
 		gof.pkg = pkg
 	}
 
@@ -207,7 +209,3 @@
 	}
 	pkg.errors = append(pkg.errors, errs...)
 }
-
-func (imp *importer) ignoreFuncBodies(pkgPath string) bool {
-	return imp.topLevelPkgPath != pkgPath
-}
diff --git a/internal/lsp/cache/gofile.go b/internal/lsp/cache/gofile.go
index 350c84e..5fcf170 100644
--- a/internal/lsp/cache/gofile.go
+++ b/internal/lsp/cache/gofile.go
@@ -29,16 +29,19 @@
 	f.view.mu.Lock()
 	defer f.view.mu.Unlock()
 
-	if f.isDirty() {
+	if f.isDirty() || f.astIsTrimmed() {
 		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
 			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
 			return nil
 		}
 	}
+	if unexpectedAST(ctx, f) {
+		return nil
+	}
 	return f.token
 }
 
-func (f *goFile) GetTrimmedAST(ctx context.Context) *ast.File {
+func (f *goFile) GetAnyAST(ctx context.Context) *ast.File {
 	f.view.mu.Lock()
 	defer f.view.mu.Unlock()
 
@@ -48,6 +51,9 @@
 			return nil
 		}
 	}
+	if f.ast == nil {
+		return nil
+	}
 	return f.ast.file
 }
 
@@ -61,6 +67,9 @@
 			return nil
 		}
 	}
+	if unexpectedAST(ctx, f) {
+		return nil
+	}
 	return f.ast.file
 }
 
@@ -79,13 +88,30 @@
 			return nil
 		}
 	}
+	if unexpectedAST(ctx, f) {
+		return nil
+	}
 	return f.pkg
 }
 
+func unexpectedAST(ctx context.Context, f *goFile) bool {
+	// If the AST comes back nil, something has gone wrong.
+	if f.ast == nil {
+		f.View().Session().Logger().Errorf(ctx, "expected full AST for %s, returned nil", f.URI())
+		return true
+	}
+	// If the AST comes back trimmed, something has gone wrong.
+	if f.astIsTrimmed() {
+		f.View().Session().Logger().Errorf(ctx, "expected full AST for %s, returned trimmed", f.URI())
+		return true
+	}
+	return false
+}
+
 // isDirty is true if the file needs to be type-checked.
 // It assumes that the file's view's mutex is held by the caller.
 func (f *goFile) isDirty() bool {
-	return f.meta == nil || f.imports == nil || f.token == nil || f.ast == nil || f.pkg == nil
+	return f.meta == nil || f.token == nil || f.ast == nil || f.pkg == nil
 }
 
 func (f *goFile) astIsTrimmed() bool {
diff --git a/internal/lsp/cache/load.go b/internal/lsp/cache/load.go
index 9a8ccdb..6e669ce 100644
--- a/internal/lsp/cache/load.go
+++ b/internal/lsp/cache/load.go
@@ -13,27 +13,26 @@
 	v.mcache.mu.Lock()
 	defer v.mcache.mu.Unlock()
 
-	// If the package for the file has not been invalidated by the application
-	// of the pending changes, there is no need to continue.
-	if !f.isDirty() {
-		return nil, nil
+	// If the AST for this file is trimmed, and we are explicitly type-checking it,
+	// don't ignore function bodies.
+	if f.astIsTrimmed() {
+		f.invalidateAST()
 	}
 
 	// Check if we need to run go/packages.Load for this file's package.
 	if errs, err := v.checkMetadata(ctx, f); err != nil {
 		return errs, err
 	}
-
 	if f.meta == nil {
 		return nil, fmt.Errorf("loadParseTypecheck: no metadata found for %v", f.filename())
 	}
 
 	imp := &importer{
-		view:            v,
-		seen:            make(map[string]struct{}),
-		ctx:             ctx,
-		fset:            f.FileSet(),
-		topLevelPkgPath: f.meta.pkgPath,
+		view:          v,
+		seen:          make(map[string]struct{}),
+		ctx:           ctx,
+		fset:          f.FileSet(),
+		topLevelPkgID: f.meta.id,
 	}
 
 	// Start prefetching direct imports.
@@ -47,7 +46,7 @@
 	}
 	// If we still have not found the package for the file, something is wrong.
 	if f.pkg == nil {
-		return nil, fmt.Errorf("parse: no package found for %v", f.filename())
+		return nil, fmt.Errorf("loadParseTypeCheck: no package found for %v", f.filename())
 	}
 	return nil, nil
 }
@@ -61,7 +60,7 @@
 	pkgs, err := packages.Load(v.buildConfig(), fmt.Sprintf("file=%s", f.filename()))
 	if len(pkgs) == 0 {
 		if err == nil {
-			err = fmt.Errorf("%s: no packages found", f.filename())
+			err = fmt.Errorf("no packages found for %s", f.filename())
 		}
 		// Return this error as a diagnostic to the user.
 		return []packages.Error{
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index 8795883..1b41b86 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -34,10 +34,10 @@
 // Because files are scanned in parallel, the token.Pos
 // positions of the resulting ast.Files are not ordered.
 //
-func (imp *importer) parseFiles(filenames []string, ignoreFuncBodies bool) ([]*ast.File, []error) {
+func (imp *importer) parseFiles(filenames []string, ignoreFuncBodies bool) ([]*astFile, []error) {
 	var wg sync.WaitGroup
 	n := len(filenames)
-	parsed := make([]*ast.File, n)
+	parsed := make([]*astFile, n)
 	errors := make([]error, n)
 	for i, filename := range filenames {
 		if imp.ctx.Err() != nil {
@@ -68,8 +68,11 @@
 
 			// If we already have a cached AST, reuse it.
 			// If the AST is trimmed, only use it if we are ignoring function bodies.
-			if gof.ast != nil && (!gof.ast.isTrimmed || ignoreFuncBodies) {
-				parsed[i], errors[i] = gof.ast.file, nil
+			if gof.astIsTrimmed() && ignoreFuncBodies {
+				parsed[i], errors[i] = gof.ast, nil
+				return
+			} else if gof.ast != nil && !gof.ast.isTrimmed && !ignoreFuncBodies {
+				parsed[i], errors[i] = gof.ast, nil
 				return
 			}
 
@@ -85,13 +88,21 @@
 			}
 
 			// ParseFile may return a partial AST and an error.
-			parsed[i], errors[i] = parseFile(imp.fset, filename, src)
+			f, err := parseFile(imp.fset, filename, src)
+
+			if ignoreFuncBodies {
+				trimAST(f)
+			}
 
 			// Fix any badly parsed parts of the AST.
-			if file := parsed[i]; file != nil {
-				tok := imp.fset.File(file.Pos())
-				imp.view.fix(imp.ctx, parsed[i], tok, src)
+			if f != nil {
+				tok := imp.fset.File(f.Pos())
+				imp.view.fix(imp.ctx, f, tok, src)
 			}
+
+			parsed[i] = &astFile{f, ignoreFuncBodies}
+			errors[i] = err
+
 		}(i, filename)
 	}
 	wg.Wait()
diff --git a/internal/lsp/cache/pkg.go b/internal/lsp/cache/pkg.go
index 9cf86b4..caf1ee5 100644
--- a/internal/lsp/cache/pkg.go
+++ b/internal/lsp/cache/pkg.go
@@ -20,7 +20,7 @@
 type pkg struct {
 	id, pkgPath string
 	files       []string
-	syntax      []*ast.File
+	syntax      []*astFile
 	errors      []packages.Error
 	imports     map[string]*pkg
 	types       *types.Package
@@ -137,7 +137,11 @@
 }
 
 func (pkg *pkg) GetSyntax() []*ast.File {
-	return pkg.syntax
+	syntax := make([]*ast.File, len(pkg.syntax))
+	for i := range pkg.syntax {
+		syntax[i] = pkg.syntax[i].file
+	}
+	return syntax
 }
 
 func (pkg *pkg) GetErrors() []packages.Error {
diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go
index 06c3c2a..9d2ee15 100644
--- a/internal/lsp/cache/view.go
+++ b/internal/lsp/cache/view.go
@@ -222,10 +222,12 @@
 	return nil
 }
 
-func (f *goFile) invalidate() {
+// invalidateContent invalidates the content of a Go file,
+// including any position and type information that depends on it.
+func (f *goFile) invalidateContent() {
 	f.view.pcache.mu.Lock()
 	defer f.view.pcache.mu.Unlock()
-	// TODO(rstambler): Should we recompute these here?
+
 	f.ast = nil
 	f.token = nil
 
@@ -236,6 +238,21 @@
 	f.handle = nil
 }
 
+// invalidateAST invalidates the AST of a Go file,
+// including any position and type information that depends on it.
+func (f *goFile) invalidateAST() {
+	f.view.pcache.mu.Lock()
+	defer f.view.pcache.mu.Unlock()
+
+	f.ast = nil
+	f.token = nil
+
+	// Remove the package and all of its reverse dependencies from the cache.
+	if f.pkg != nil {
+		f.view.remove(f.pkg.pkgPath, map[string]struct{}{})
+	}
+}
+
 // remove invalidates a package and its reverse dependencies in the view's
 // package cache. It is assumed that the caller has locked both the mutexes
 // of both the mcache and the pcache.
@@ -308,7 +325,7 @@
 			},
 		}
 		v.session.filesWatchMap.Watch(uri, func() {
-			f.(*goFile).invalidate()
+			f.(*goFile).invalidateContent()
 		})
 	case ".mod":
 		f = &modFile{
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index f90f76a..2548b46 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -134,11 +134,11 @@
 	fmt.Fprintf(msg, reason, args...)
 	fmt.Fprint(msg, ":\nexpected:\n")
 	for _, d := range want {
-		fmt.Fprintf(msg, "  %v\n", d)
+		fmt.Fprintf(msg, "  %v: %s\n", d.Span, d.Message)
 	}
 	fmt.Fprintf(msg, "got:\n")
 	for _, d := range got {
-		fmt.Fprintf(msg, "  %v\n", d)
+		fmt.Fprintf(msg, "  %v: %s\n", d.Span, d.Message)
 	}
 	return msg.String()
 }
diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go
index 06c7e0d..7d3a8e7 100644
--- a/internal/lsp/source/identifier.go
+++ b/internal/lsp/source/identifier.go
@@ -134,7 +134,7 @@
 	if result.decl.rng, err = objToRange(ctx, f.FileSet(), result.decl.obj); err != nil {
 		return nil, err
 	}
-	if result.decl.node, err = objToNode(ctx, v, result.decl.obj, result.decl.rng); err != nil {
+	if result.decl.node, err = objToNode(ctx, v, pkg.GetTypes(), result.decl.obj, result.decl.rng); err != nil {
 		return nil, err
 	}
 	typ := pkg.GetTypesInfo().TypeOf(result.ident)
@@ -180,7 +180,7 @@
 	return span.NewRange(fset, pos, pos+token.Pos(len(name))), nil
 }
 
-func objToNode(ctx context.Context, v View, obj types.Object, rng span.Range) (ast.Decl, error) {
+func objToNode(ctx context.Context, v View, originPkg *types.Package, obj types.Object, rng span.Range) (ast.Decl, error) {
 	s, err := rng.Span()
 	if err != nil {
 		return nil, err
@@ -191,12 +191,13 @@
 	}
 	declFile, ok := f.(GoFile)
 	if !ok {
-		return nil, fmt.Errorf("not a Go file %v", s.URI())
+		return nil, fmt.Errorf("%s is not a Go file", s.URI())
 	}
-	// If the object is exported, we don't need the full AST to find its definition.
+	// If the object is exported from a different package,
+	// we don't need its full AST to find the definition.
 	var declAST *ast.File
-	if obj.Exported() {
-		declAST = declFile.GetTrimmedAST(ctx)
+	if obj.Exported() && obj.Pkg() != originPkg {
+		declAST = declFile.GetAnyAST(ctx)
 	} else {
 		declAST = declFile.GetAST(ctx)
 	}
diff --git a/internal/lsp/source/signature_help.go b/internal/lsp/source/signature_help.go
index 26d304c..c953e74 100644
--- a/internal/lsp/source/signature_help.go
+++ b/internal/lsp/source/signature_help.go
@@ -91,7 +91,7 @@
 		if err != nil {
 			return nil, err
 		}
-		node, err := objToNode(ctx, f.View(), obj, rng)
+		node, err := objToNode(ctx, f.View(), pkg.GetTypes(), obj, rng)
 		if err != nil {
 			return nil, err
 		}
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index 25d436a..d349e9d 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -162,9 +162,9 @@
 type GoFile interface {
 	File
 
-	// GetTrimmedAST returns an AST that may or may not contain function bodies.
+	// GetAnyAST returns an AST that may or may not contain function bodies.
 	// It should be used in scenarios where function bodies are not necessary.
-	GetTrimmedAST(ctx context.Context) *ast.File
+	GetAnyAST(ctx context.Context) *ast.File
 
 	// GetAST returns the full AST for the file.
 	GetAST(ctx context.Context) *ast.File